# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module for constructing RNN Cells."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import math

from tensorflow.contrib.compiler import jit
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl  # pylint: disable=unused-import
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables  # pylint: disable=unused-import
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest


def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor."""
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES, concat_variable)
  return concat_variable


def _get_sharded_variable(name, shape, dtype, num_shards):
  """Get a list of sharded variables with the given dtype."""
  if num_shards > shape[0]:
    raise ValueError("Too many shards: shape=%s, num_shards=%d" % (shape,
                                                                   num_shards))
  unit_shard_size = int(math.floor(shape[0] / num_shards))
  remaining_rows = shape[0] - unit_shard_size * num_shards

  shards = []
  for i in range(num_shards):
    current_size = unit_shard_size
    if i < remaining_rows:
      current_size += 1
    shards.append(
        vs.get_variable(
            name + "_%d" % i, [current_size] + shape[1:], dtype=dtype))
  return shards


def _norm(g, b, inp, scope):
  shape = inp.get_shape()[-1:]
  gamma_init = init_ops.constant_initializer(g)
  beta_init = init_ops.constant_initializer(b)
  with vs.variable_scope(scope):
    # Initialize beta and gamma for use by layer_norm.
    vs.get_variable("gamma", shape=shape, initializer=gamma_init)
    vs.get_variable("beta", shape=shape, initializer=beta_init)
  normalized = layers.layer_norm(inp, reuse=True, scope=scope)
  return normalized


class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
  """Long short-term memory unit (LSTM) recurrent network cell.

  The default non-peephole implementation is based on:

    http://www.bioinf.jku.at/publications/older/2604.pdf

  S. Hochreiter and J. Schmidhuber.
  "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.

  The peephole implementation is based on:

    https://research.google.com/pubs/archive/43905.pdf

  Hasim Sak, Andrew Senior, and Francoise Beaufays.
  "Long short-term memory recurrent neural network architectures for
   large scale acoustic modeling." INTERSPEECH, 2014.

  The coupling of input and forget gate is based on:

    http://arxiv.org/pdf/1503.04069.pdf

  Greff et al. "LSTM: A Search Space Odyssey"

  The class uses optional peep-hole connections, and an optional projection
  layer.
  Layer normalization implementation is based on:

    https://arxiv.org/abs/1607.06450.

  "Layer Normalization"
  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton

  and is applied before the internal nonlinearities.

  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               initializer=None,
               num_proj=None,
               proj_clip=None,
               num_unit_shards=1,
               num_proj_shards=1,
               forget_bias=1.0,
               state_is_tuple=True,
               activation=math_ops.tanh,
               reuse=None,
               layer_norm=False,
               norm_gain=1.0,
               norm_shift=0.0):
    """Initialize the parameters for an LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell
      use_peepholes: bool, set True to enable diagonal/peephole connections.
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices.  If None, no projection is performed.
      proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
      provided, then the projected values are clipped elementwise to within
      `[-proj_clip, proj_clip]`.
      num_unit_shards: How to split the weight matrix.  If >1, the weight
        matrix is stored across num_unit_shards.
      num_proj_shards: How to split the projection matrix.  If >1, the
        projection matrix is stored across num_proj_shards.
      forget_bias: Biases of the forget gate are initialized by default to 1
        in order to reduce the scale of forgetting at the beginning of
        the training.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`.  By default (False), they are concatenated
        along the column axis.  This default behavior will soon be deprecated.
      activation: Activation function of the inner states.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.
      layer_norm: If `True`, layer normalization will be applied.
      norm_gain: float, The layer normalization gain initial value. If
        `layer_norm` has been set to `False`, this argument will be ignored.
      norm_shift: float, The layer normalization shift initial value. If
        `layer_norm` has been set to `False`, this argument will be ignored.
    """
    super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse)
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated.  Use state_is_tuple=True.", self)
    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._initializer = initializer
    self._num_proj = num_proj
    self._proj_clip = proj_clip
    self._num_unit_shards = num_unit_shards
    self._num_proj_shards = num_proj_shards
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation
    self._reuse = reuse
    self._layer_norm = layer_norm
    self._norm_gain = norm_gain
    self._norm_shift = norm_shift

    if num_proj:
      self._state_size = (
          rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
          if state_is_tuple else num_units + num_proj)
      self._output_size = num_proj
    else:
      self._state_size = (
          rnn_cell_impl.LSTMStateTuple(num_units, num_units)
          if state_is_tuple else 2 * num_units)
      self._output_size = num_units

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, batch x num_units.
      state: if `state_is_tuple` is False, this must be a state Tensor,
        `2-D, batch x state_size`.  If `state_is_tuple` is True, this must be a
        tuple of state Tensors, both `2-D`, with column sizes `c_state` and
        `m_state`.

    Returns:
      A tuple containing:
      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of LSTM after reading `inputs` when
        the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    sigmoid = math_ops.sigmoid

    num_proj = self._num_units if self._num_proj is None else self._num_proj

    if self._state_is_tuple:
      (c_prev, m_prev) = state
    else:
      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

    dtype = inputs.dtype
    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
    concat_w = _get_concat_variable(
        "W", [input_size.value + num_proj, 3 * self._num_units], dtype,
        self._num_unit_shards)

    b = vs.get_variable(
        "B",
        shape=[3 * self._num_units],
        initializer=init_ops.zeros_initializer(),
        dtype=dtype)

    # j = new_input, f = forget_gate, o = output_gate
    cell_inputs = array_ops.concat([inputs, m_prev], 1)
    lstm_matrix = math_ops.matmul(cell_inputs, concat_w)

    # If layer nomalization is applied, do not add bias
    if not self._layer_norm:
      lstm_matrix = nn_ops.bias_add(lstm_matrix, b)

    j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=3, axis=1)

    # Apply layer normalization
    if self._layer_norm:
      j = _norm(self._norm_gain, self._norm_shift, j, "transform")
      f = _norm(self._norm_gain, self._norm_shift, f, "forget")
      o = _norm(self._norm_gain, self._norm_shift, o, "output")

    # Diagonal connections
    if self._use_peepholes:
      w_f_diag = vs.get_variable(
          "W_F_diag", shape=[self._num_units], dtype=dtype)
      w_o_diag = vs.get_variable(
          "W_O_diag", shape=[self._num_units], dtype=dtype)

    if self._use_peepholes:
      f_act = sigmoid(f + self._forget_bias + w_f_diag * c_prev)
    else:
      f_act = sigmoid(f + self._forget_bias)
    c = (f_act * c_prev + (1 - f_act) * self._activation(j))

    # Apply layer normalization
    if self._layer_norm:
      c = _norm(self._norm_gain, self._norm_shift, c, "state")

    if self._use_peepholes:
      m = sigmoid(o + w_o_diag * c) * self._activation(c)
    else:
      m = sigmoid(o) * self._activation(c)

    if self._num_proj is not None:
      concat_w_proj = _get_concat_variable("W_P",
                                           [self._num_units, self._num_proj],
                                           dtype, self._num_proj_shards)

      m = math_ops.matmul(m, concat_w_proj)
      if self._proj_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
        # pylint: enable=invalid-unary-operand-type

    new_state = (
        rnn_cell_impl.LSTMStateTuple(c, m)
        if self._state_is_tuple else array_ops.concat([c, m], 1))
    return m, new_state


class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
  """Time-Frequency Long short-term memory unit (LSTM) recurrent network cell.

  This implementation is based on:

    Tara N. Sainath and Bo Li
    "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures
    for LVCSR Tasks." submitted to INTERSPEECH, 2016.

  It uses peep-hole connections and optional cell clipping.
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               cell_clip=None,
               initializer=None,
               num_unit_shards=1,
               forget_bias=1.0,
               feature_size=None,
               frequency_skip=1,
               reuse=None):
    """Initialize the parameters for an LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell
      use_peepholes: bool, set True to enable diagonal/peephole connections.
      cell_clip: (optional) A float value, if provided the cell state is clipped
        by this value prior to the cell output activation.
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      num_unit_shards: int, How to split the weight matrix.  If >1, the weight
        matrix is stored across num_unit_shards.
      forget_bias: float, Biases of the forget gate are initialized by default
        to 1 in order to reduce the scale of forgetting at the beginning
        of the training.
      feature_size: int, The size of the input feature the LSTM spans over.
      frequency_skip: int, The amount the LSTM filter is shifted by in
        frequency.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(TimeFreqLSTMCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._cell_clip = cell_clip
    self._initializer = initializer
    self._num_unit_shards = num_unit_shards
    self._forget_bias = forget_bias
    self._feature_size = feature_size
    self._frequency_skip = frequency_skip
    self._state_size = 2 * num_units
    self._output_size = num_units
    self._reuse = reuse

  @property
  def output_size(self):
    return self._output_size

  @property
  def state_size(self):
    return self._state_size

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, batch x num_units.
      state: state Tensor, 2D, batch x state_size.

    Returns:
      A tuple containing:
      - A 2D, batch x output_dim, Tensor representing the output of the LSTM
        after reading "inputs" when previous state was "state".
        Here output_dim is num_units.
      - A 2D, batch x state_size, Tensor representing the new state of LSTM
        after reading "inputs" when previous state was "state".
    Raises:
      ValueError: if an input_size was specified and the provided inputs have
        a different dimension.
    """
    sigmoid = math_ops.sigmoid
    tanh = math_ops.tanh

    freq_inputs = self._make_tf_features(inputs)
    dtype = inputs.dtype
    actual_input_size = freq_inputs[0].get_shape().as_list()[1]

    concat_w = _get_concat_variable(
        "W", [actual_input_size + 2 * self._num_units, 4 * self._num_units],
        dtype, self._num_unit_shards)

    b = vs.get_variable(
        "B",
        shape=[4 * self._num_units],
        initializer=init_ops.zeros_initializer(),
        dtype=dtype)

    # Diagonal connections
    if self._use_peepholes:
      w_f_diag = vs.get_variable(
          "W_F_diag", shape=[self._num_units], dtype=dtype)
      w_i_diag = vs.get_variable(
          "W_I_diag", shape=[self._num_units], dtype=dtype)
      w_o_diag = vs.get_variable(
          "W_O_diag", shape=[self._num_units], dtype=dtype)

    # initialize the first freq state to be zero
    m_prev_freq = array_ops.zeros(
        [inputs.shape[0].value or inputs.get_shape()[0], self._num_units],
        dtype)
    for fq in range(len(freq_inputs)):
      c_prev = array_ops.slice(state, [0, 2 * fq * self._num_units],
                               [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, (2 * fq + 1) * self._num_units],
                               [-1, self._num_units])
      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq], 1)
      lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
      i, j, f, o = array_ops.split(
          value=lstm_matrix, num_or_size_splits=4, axis=1)

      if self._use_peepholes:
        c = (
            sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
            sigmoid(i + w_i_diag * c_prev) * tanh(j))
      else:
        c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j))

      if self._cell_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
        # pylint: enable=invalid-unary-operand-type

      if self._use_peepholes:
        m = sigmoid(o + w_o_diag * c) * tanh(c)
      else:
        m = sigmoid(o) * tanh(c)
      m_prev_freq = m
      if fq == 0:
        state_out = array_ops.concat([c, m], 1)
        m_out = m
      else:
        state_out = array_ops.concat([state_out, c, m], 1)
        m_out = array_ops.concat([m_out, m], 1)
    return m_out, state_out

  def _make_tf_features(self, input_feat):
    """Make the frequency features.

    Args:
      input_feat: input Tensor, 2D, batch x num_units.

    Returns:
      A list of frequency features, with each element containing:
      - A 2D, batch x output_dim, Tensor representing the time-frequency feature
        for that frequency index. Here output_dim is feature_size.
    Raises:
      ValueError: if input_size cannot be inferred from static shape inference.
    """
    input_size = input_feat.get_shape().with_rank(2)[-1].value
    if input_size is None:
      raise ValueError("Cannot infer input_size from static shape inference.")
    num_feats = int(
        (input_size - self._feature_size) / (self._frequency_skip)) + 1
    freq_inputs = []
    for f in range(num_feats):
      cur_input = array_ops.slice(input_feat, [0, f * self._frequency_skip],
                                  [-1, self._feature_size])
      freq_inputs.append(cur_input)
    return freq_inputs


class GridLSTMCell(rnn_cell_impl.RNNCell):
  """Grid Long short-term memory unit (LSTM) recurrent network cell.

  The default is based on:
    Nal Kalchbrenner, Ivo Danihelka and Alex Graves
    "Grid Long Short-Term Memory," Proc. ICLR 2016.
    http://arxiv.org/abs/1507.01526

  When peephole connections are used, the implementation is based on:
    Tara N. Sainath and Bo Li
    "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures
    for LVCSR Tasks." submitted to INTERSPEECH, 2016.

  The code uses optional peephole connections, shared_weights and cell clipping.
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               share_time_frequency_weights=False,
               cell_clip=None,
               initializer=None,
               num_unit_shards=1,
               forget_bias=1.0,
               feature_size=None,
               frequency_skip=None,
               num_frequency_blocks=None,
               start_freqindex_list=None,
               end_freqindex_list=None,
               couple_input_forget_gates=False,
               state_is_tuple=True,
               reuse=None):
    """Initialize the parameters for an LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell
      use_peepholes: (optional) bool, default False. Set True to enable
        diagonal/peephole connections.
      share_time_frequency_weights: (optional) bool, default False. Set True to
        enable shared cell weights between time and frequency LSTMs.
      cell_clip: (optional) A float value, default None, if provided the cell
        state is clipped by this value prior to the cell output activation.
      initializer: (optional) The initializer to use for the weight and
        projection matrices, default None.
      num_unit_shards: (optional) int, default 1, How to split the weight
        matrix. If > 1, the weight matrix is stored across num_unit_shards.
      forget_bias: (optional) float, default 1.0, The initial bias of the
        forget gates, used to reduce the scale of forgetting at the beginning
        of the training.
      feature_size: (optional) int, default None, The size of the input feature
        the LSTM spans over.
      frequency_skip: (optional) int, default None, The amount the LSTM filter
        is shifted by in frequency.
      num_frequency_blocks: [required] A list of frequency blocks needed to
        cover the whole input feature splitting defined by start_freqindex_list
        and end_freqindex_list.
      start_freqindex_list: [optional], list of ints, default None,  The
        starting frequency index for each frequency block.
      end_freqindex_list: [optional], list of ints, default None. The ending
        frequency index for each frequency block.
      couple_input_forget_gates: (optional) bool, default False, Whether to
        couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce
        model parameters and computation cost.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`.  By default (False), they are concatenated
        along the column axis.  This default behavior will soon be deprecated.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.
    Raises:
      ValueError: if the num_frequency_blocks list is not specified
    """
    super(GridLSTMCell, self).__init__(_reuse=reuse)
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated.  Use state_is_tuple=True.", self)
    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._share_time_frequency_weights = share_time_frequency_weights
    self._couple_input_forget_gates = couple_input_forget_gates
    self._state_is_tuple = state_is_tuple
    self._cell_clip = cell_clip
    self._initializer = initializer
    self._num_unit_shards = num_unit_shards
    self._forget_bias = forget_bias
    self._feature_size = feature_size
    self._frequency_skip = frequency_skip
    self._start_freqindex_list = start_freqindex_list
    self._end_freqindex_list = end_freqindex_list
    self._num_frequency_blocks = num_frequency_blocks
    self._total_blocks = 0
    self._reuse = reuse
    if self._num_frequency_blocks is None:
      raise ValueError("Must specify num_frequency_blocks")

    for block_index in range(len(self._num_frequency_blocks)):
      self._total_blocks += int(self._num_frequency_blocks[block_index])
    if state_is_tuple:
      state_names = ""
      for block_index in range(len(self._num_frequency_blocks)):
        for freq_index in range(self._num_frequency_blocks[block_index]):
          name_prefix = "state_f%02d_b%02d" % (freq_index, block_index)
          state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
      self._state_tuple_type = collections.namedtuple("GridLSTMStateTuple",
                                                      state_names.strip(","))
      self._state_size = self._state_tuple_type(*(
          [num_units, num_units] * self._total_blocks))
    else:
      self._state_tuple_type = None
      self._state_size = num_units * self._total_blocks * 2
    self._output_size = num_units * self._total_blocks * 2

  @property
  def output_size(self):
    return self._output_size

  @property
  def state_size(self):
    return self._state_size

  @property
  def state_tuple_type(self):
    return self._state_tuple_type

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, [batch, feature_size].
      state: Tensor or tuple of Tensors, 2D, [batch, state_size], depends on the
        flag self._state_is_tuple.

    Returns:
      A tuple containing:
      - A 2D, [batch, output_dim], Tensor representing the output of the LSTM
        after reading "inputs" when previous state was "state".
        Here output_dim is num_units.
      - A 2D, [batch, state_size], Tensor representing the new state of LSTM
        after reading "inputs" when previous state was "state".
    Raises:
      ValueError: if an input_size was specified and the provided inputs have
        a different dimension.
    """
    batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
    freq_inputs = self._make_tf_features(inputs)
    m_out_lst = []
    state_out_lst = []
    for block in range(len(freq_inputs)):
      m_out_lst_current, state_out_lst_current = self._compute(
          freq_inputs[block],
          block,
          state,
          batch_size,
          state_is_tuple=self._state_is_tuple)
      m_out_lst.extend(m_out_lst_current)
      state_out_lst.extend(state_out_lst_current)
    if self._state_is_tuple:
      state_out = self._state_tuple_type(*state_out_lst)
    else:
      state_out = array_ops.concat(state_out_lst, 1)
    m_out = array_ops.concat(m_out_lst, 1)
    return m_out, state_out

  def _compute(self,
               freq_inputs,
               block,
               state,
               batch_size,
               state_prefix="state",
               state_is_tuple=True):
    """Run the actual computation of one step LSTM.

    Args:
      freq_inputs: list of Tensors, 2D, [batch, feature_size].
      block: int, current frequency block index to process.
      state: Tensor or tuple of Tensors, 2D, [batch, state_size], it depends on
        the flag state_is_tuple.
      batch_size: int32, batch size.
      state_prefix: (optional) string, name prefix for states, defaults to
        "state".
      state_is_tuple: boolean, indicates whether the state is a tuple or Tensor.

    Returns:
      A tuple, containing:
      - A list of [batch, output_dim] Tensors, representing the output of the
        LSTM given the inputs and state.
      - A list of [batch, state_size] Tensors, representing the LSTM state
        values given the inputs and previous state.
    """
    sigmoid = math_ops.sigmoid
    tanh = math_ops.tanh
    num_gates = 3 if self._couple_input_forget_gates else 4
    dtype = freq_inputs[0].dtype
    actual_input_size = freq_inputs[0].get_shape().as_list()[1]

    concat_w_f = _get_concat_variable(
        "W_f_%d" % block,
        [actual_input_size + 2 * self._num_units, num_gates * self._num_units],
        dtype, self._num_unit_shards)
    b_f = vs.get_variable(
        "B_f_%d" % block,
        shape=[num_gates * self._num_units],
        initializer=init_ops.zeros_initializer(),
        dtype=dtype)
    if not self._share_time_frequency_weights:
      concat_w_t = _get_concat_variable("W_t_%d" % block, [
          actual_input_size + 2 * self._num_units, num_gates * self._num_units
      ], dtype, self._num_unit_shards)
      b_t = vs.get_variable(
          "B_t_%d" % block,
          shape=[num_gates * self._num_units],
          initializer=init_ops.zeros_initializer(),
          dtype=dtype)

    if self._use_peepholes:
      # Diagonal connections
      if not self._couple_input_forget_gates:
        w_f_diag_freqf = vs.get_variable(
            "W_F_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
        w_f_diag_freqt = vs.get_variable(
            "W_F_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
      w_i_diag_freqf = vs.get_variable(
          "W_I_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
      w_i_diag_freqt = vs.get_variable(
          "W_I_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
      w_o_diag_freqf = vs.get_variable(
          "W_O_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
      w_o_diag_freqt = vs.get_variable(
          "W_O_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
      if not self._share_time_frequency_weights:
        if not self._couple_input_forget_gates:
          w_f_diag_timef = vs.get_variable(
              "W_F_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
          w_f_diag_timet = vs.get_variable(
              "W_F_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)
        w_i_diag_timef = vs.get_variable(
            "W_I_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
        w_i_diag_timet = vs.get_variable(
            "W_I_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)
        w_o_diag_timef = vs.get_variable(
            "W_O_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
        w_o_diag_timet = vs.get_variable(
            "W_O_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)

    # initialize the first freq state to be zero
    m_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype)
    c_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype)
    for freq_index in range(len(freq_inputs)):
      if state_is_tuple:
        name_prefix = "%s_f%02d_b%02d" % (state_prefix, freq_index, block)
        c_prev_time = getattr(state, name_prefix + "_c")
        m_prev_time = getattr(state, name_prefix + "_m")
      else:
        c_prev_time = array_ops.slice(
            state, [0, 2 * freq_index * self._num_units], [-1, self._num_units])
        m_prev_time = array_ops.slice(
            state, [0, (2 * freq_index + 1) * self._num_units],
            [-1, self._num_units])

      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      cell_inputs = array_ops.concat(
          [freq_inputs[freq_index], m_prev_time, m_prev_freq], 1)

      # F-LSTM
      lstm_matrix_freq = nn_ops.bias_add(
          math_ops.matmul(cell_inputs, concat_w_f), b_f)
      if self._couple_input_forget_gates:
        i_freq, j_freq, o_freq = array_ops.split(
            value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1)
        f_freq = None
      else:
        i_freq, j_freq, f_freq, o_freq = array_ops.split(
            value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1)
      # T-LSTM
      if self._share_time_frequency_weights:
        i_time = i_freq
        j_time = j_freq
        f_time = f_freq
        o_time = o_freq
      else:
        lstm_matrix_time = nn_ops.bias_add(
            math_ops.matmul(cell_inputs, concat_w_t), b_t)
        if self._couple_input_forget_gates:
          i_time, j_time, o_time = array_ops.split(
              value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1)
          f_time = None
        else:
          i_time, j_time, f_time, o_time = array_ops.split(
              value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1)

      # F-LSTM c_freq
      # input gate activations
      if self._use_peepholes:
        i_freq_g = sigmoid(i_freq + w_i_diag_freqf * c_prev_freq +
                           w_i_diag_freqt * c_prev_time)
      else:
        i_freq_g = sigmoid(i_freq)
      # forget gate activations
      if self._couple_input_forget_gates:
        f_freq_g = 1.0 - i_freq_g
      else:
        if self._use_peepholes:
          f_freq_g = sigmoid(f_freq + self._forget_bias + w_f_diag_freqf *
                             c_prev_freq + w_f_diag_freqt * c_prev_time)
        else:
          f_freq_g = sigmoid(f_freq + self._forget_bias)
      # cell state
      c_freq = f_freq_g * c_prev_freq + i_freq_g * tanh(j_freq)
      if self._cell_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        c_freq = clip_ops.clip_by_value(c_freq, -self._cell_clip,
                                        self._cell_clip)
        # pylint: enable=invalid-unary-operand-type

      # T-LSTM c_freq
      # input gate activations
      if self._use_peepholes:
        if self._share_time_frequency_weights:
          i_time_g = sigmoid(i_time + w_i_diag_freqf * c_prev_freq +
                             w_i_diag_freqt * c_prev_time)
        else:
          i_time_g = sigmoid(i_time + w_i_diag_timef * c_prev_freq +
                             w_i_diag_timet * c_prev_time)
      else:
        i_time_g = sigmoid(i_time)
      # forget gate activations
      if self._couple_input_forget_gates:
        f_time_g = 1.0 - i_time_g
      else:
        if self._use_peepholes:
          if self._share_time_frequency_weights:
            f_time_g = sigmoid(f_time + self._forget_bias + w_f_diag_freqf *
                               c_prev_freq + w_f_diag_freqt * c_prev_time)
          else:
            f_time_g = sigmoid(f_time + self._forget_bias + w_f_diag_timef *
                               c_prev_freq + w_f_diag_timet * c_prev_time)
        else:
          f_time_g = sigmoid(f_time + self._forget_bias)
      # cell state
      c_time = f_time_g * c_prev_time + i_time_g * tanh(j_time)
      if self._cell_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        c_time = clip_ops.clip_by_value(c_time, -self._cell_clip,
                                        self._cell_clip)
        # pylint: enable=invalid-unary-operand-type

      # F-LSTM m_freq
      if self._use_peepholes:
        m_freq = sigmoid(o_freq + w_o_diag_freqf * c_freq +
                         w_o_diag_freqt * c_time) * tanh(c_freq)
      else:
        m_freq = sigmoid(o_freq) * tanh(c_freq)

      # T-LSTM m_time
      if self._use_peepholes:
        if self._share_time_frequency_weights:
          m_time = sigmoid(o_time + w_o_diag_freqf * c_freq +
                           w_o_diag_freqt * c_time) * tanh(c_time)
        else:
          m_time = sigmoid(o_time + w_o_diag_timef * c_freq +
                           w_o_diag_timet * c_time) * tanh(c_time)
      else:
        m_time = sigmoid(o_time) * tanh(c_time)

      m_prev_freq = m_freq
      c_prev_freq = c_freq
      # Concatenate the outputs for T-LSTM and F-LSTM for each shift
      if freq_index == 0:
        state_out_lst = [c_time, m_time]
        m_out_lst = [m_time, m_freq]
      else:
        state_out_lst.extend([c_time, m_time])
        m_out_lst.extend([m_time, m_freq])

    return m_out_lst, state_out_lst

  def _make_tf_features(self, input_feat, slice_offset=0):
    """Make the frequency features.

    Args:
      input_feat: input Tensor, 2D, [batch, num_units].
      slice_offset: (optional) Python int, default 0, the slicing offset is only
        used for the backward processing in the BidirectionalGridLSTMCell. It
        specifies a different starting point instead of always 0 to enable the
        forward and backward processing look at different frequency blocks.

    Returns:
      A list of frequency features, with each element containing:
      - A 2D, [batch, output_dim], Tensor representing the time-frequency
        feature for that frequency index. Here output_dim is feature_size.
    Raises:
      ValueError: if input_size cannot be inferred from static shape inference.
    """
    input_size = input_feat.get_shape().with_rank(2)[-1].value
    if input_size is None:
      raise ValueError("Cannot infer input_size from static shape inference.")
    if slice_offset > 0:
      # Padding to the end
      inputs = array_ops.pad(input_feat,
                             array_ops.constant(
                                 [0, 0, 0, slice_offset],
                                 shape=[2, 2],
                                 dtype=dtypes.int32), "CONSTANT")
    elif slice_offset < 0:
      # Padding to the front
      inputs = array_ops.pad(input_feat,
                             array_ops.constant(
                                 [0, 0, -slice_offset, 0],
                                 shape=[2, 2],
                                 dtype=dtypes.int32), "CONSTANT")
      slice_offset = 0
    else:
      inputs = input_feat
    freq_inputs = []
    if not self._start_freqindex_list:
      if len(self._num_frequency_blocks) != 1:
        raise ValueError("Length of num_frequency_blocks"
                         " is not 1, but instead is %d",
                         len(self._num_frequency_blocks))
      num_feats = int(
          (input_size - self._feature_size) / (self._frequency_skip)) + 1
      if num_feats != self._num_frequency_blocks[0]:
        raise ValueError(
            "Invalid num_frequency_blocks, requires %d but gets %d, please"
            " check the input size and filter config are correct." %
            (self._num_frequency_blocks[0], num_feats))
      block_inputs = []
      for f in range(num_feats):
        cur_input = array_ops.slice(
            inputs, [0, slice_offset + f * self._frequency_skip],
            [-1, self._feature_size])
        block_inputs.append(cur_input)
      freq_inputs.append(block_inputs)
    else:
      if len(self._start_freqindex_list) != len(self._end_freqindex_list):
        raise ValueError("Length of start and end freqindex_list"
                         " does not match %d %d",
                         len(self._start_freqindex_list),
                         len(self._end_freqindex_list))
      if len(self._num_frequency_blocks) != len(self._start_freqindex_list):
        raise ValueError("Length of num_frequency_blocks"
                         " is not equal to start_freqindex_list %d %d",
                         len(self._num_frequency_blocks),
                         len(self._start_freqindex_list))
      for b in range(len(self._start_freqindex_list)):
        start_index = self._start_freqindex_list[b]
        end_index = self._end_freqindex_list[b]
        cur_size = end_index - start_index
        block_feats = int(
            (cur_size - self._feature_size) / (self._frequency_skip)) + 1
        if block_feats != self._num_frequency_blocks[b]:
          raise ValueError(
              "Invalid num_frequency_blocks, requires %d but gets %d, please"
              " check the input size and filter config are correct." %
              (self._num_frequency_blocks[b], block_feats))
        block_inputs = []
        for f in range(block_feats):
          cur_input = array_ops.slice(
              inputs,
              [0, start_index + slice_offset + f * self._frequency_skip],
              [-1, self._feature_size])
          block_inputs.append(cur_input)
        freq_inputs.append(block_inputs)
    return freq_inputs


class BidirectionalGridLSTMCell(GridLSTMCell):
  """Bidirectional GridLstm cell.

  The bidirection connection is only used in the frequency direction, which
  hence doesn't affect the time direction's real-time processing that is
  required for online recognition systems.
  The current implementation uses different weights for the two directions.
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               share_time_frequency_weights=False,
               cell_clip=None,
               initializer=None,
               num_unit_shards=1,
               forget_bias=1.0,
               feature_size=None,
               frequency_skip=None,
               num_frequency_blocks=None,
               start_freqindex_list=None,
               end_freqindex_list=None,
               couple_input_forget_gates=False,
               backward_slice_offset=0,
               reuse=None):
    """Initialize the parameters for an LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell
      use_peepholes: (optional) bool, default False. Set True to enable
        diagonal/peephole connections.
      share_time_frequency_weights: (optional) bool, default False. Set True to
        enable shared cell weights between time and frequency LSTMs.
      cell_clip: (optional) A float value, default None, if provided the cell
        state is clipped by this value prior to the cell output activation.
      initializer: (optional) The initializer to use for the weight and
        projection matrices, default None.
      num_unit_shards: (optional) int, default 1, How to split the weight
        matrix. If > 1, the weight matrix is stored across num_unit_shards.
      forget_bias: (optional) float, default 1.0, The initial bias of the
        forget gates, used to reduce the scale of forgetting at the beginning
        of the training.
      feature_size: (optional) int, default None, The size of the input feature
        the LSTM spans over.
      frequency_skip: (optional) int, default None, The amount the LSTM filter
        is shifted by in frequency.
      num_frequency_blocks: [required] A list of frequency blocks needed to
        cover the whole input feature splitting defined by start_freqindex_list
        and end_freqindex_list.
      start_freqindex_list: [optional], list of ints, default None,  The
        starting frequency index for each frequency block.
      end_freqindex_list: [optional], list of ints, default None. The ending
        frequency index for each frequency block.
      couple_input_forget_gates: (optional) bool, default False, Whether to
        couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce
        model parameters and computation cost.
      backward_slice_offset: (optional) int32, default 0, the starting offset to
        slice the feature for backward processing.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(BidirectionalGridLSTMCell, self).__init__(
        num_units, use_peepholes, share_time_frequency_weights, cell_clip,
        initializer, num_unit_shards, forget_bias, feature_size, frequency_skip,
        num_frequency_blocks, start_freqindex_list, end_freqindex_list,
        couple_input_forget_gates, True, reuse)
    self._backward_slice_offset = int(backward_slice_offset)
    state_names = ""
    for direction in ["fwd", "bwd"]:
      for block_index in range(len(self._num_frequency_blocks)):
        for freq_index in range(self._num_frequency_blocks[block_index]):
          name_prefix = "%s_state_f%02d_b%02d" % (direction, freq_index,
                                                  block_index)
          state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
    self._state_tuple_type = collections.namedtuple(
        "BidirectionalGridLSTMStateTuple", state_names.strip(","))
    self._state_size = self._state_tuple_type(*(
        [num_units, num_units] * self._total_blocks * 2))
    self._output_size = 2 * num_units * self._total_blocks * 2

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, [batch, num_units].
      state: tuple of Tensors, 2D, [batch, state_size].

    Returns:
      A tuple containing:
      - A 2D, [batch, output_dim], Tensor representing the output of the LSTM
        after reading "inputs" when previous state was "state".
        Here output_dim is num_units.
      - A 2D, [batch, state_size], Tensor representing the new state of LSTM
        after reading "inputs" when previous state was "state".
    Raises:
      ValueError: if an input_size was specified and the provided inputs have
        a different dimension.
    """
    batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
    fwd_inputs = self._make_tf_features(inputs)
    if self._backward_slice_offset:
      bwd_inputs = self._make_tf_features(inputs, self._backward_slice_offset)
    else:
      bwd_inputs = fwd_inputs

    # Forward processing
    with vs.variable_scope("fwd"):
      fwd_m_out_lst = []
      fwd_state_out_lst = []
      for block in range(len(fwd_inputs)):
        fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute(
            fwd_inputs[block],
            block,
            state,
            batch_size,
            state_prefix="fwd_state",
            state_is_tuple=True)
        fwd_m_out_lst.extend(fwd_m_out_lst_current)
        fwd_state_out_lst.extend(fwd_state_out_lst_current)
    # Backward processing
    bwd_m_out_lst = []
    bwd_state_out_lst = []
    with vs.variable_scope("bwd"):
      for block in range(len(bwd_inputs)):
        # Reverse the blocks
        bwd_inputs_reverse = bwd_inputs[block][::-1]
        bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute(
            bwd_inputs_reverse,
            block,
            state,
            batch_size,
            state_prefix="bwd_state",
            state_is_tuple=True)
        bwd_m_out_lst.extend(bwd_m_out_lst_current)
        bwd_state_out_lst.extend(bwd_state_out_lst_current)
    state_out = self._state_tuple_type(*(fwd_state_out_lst + bwd_state_out_lst))
    # Outputs are always concated as it is never used separately.
    m_out = array_ops.concat(fwd_m_out_lst + bwd_m_out_lst, 1)
    return m_out, state_out


# pylint: disable=protected-access
_Linear = core_rnn_cell._Linear  # pylint: disable=invalid-name

# pylint: enable=protected-access


class AttentionCellWrapper(rnn_cell_impl.RNNCell):
  """Basic attention cell wrapper.

  Implementation based on https://arxiv.org/abs/1409.0473.
  """

  def __init__(self,
               cell,
               attn_length,
               attn_size=None,
               attn_vec_size=None,
               input_size=None,
               state_is_tuple=True,
               reuse=None):
    """Create a cell with attention.

    Args:
      cell: an RNNCell, an attention is added to it.
      attn_length: integer, the size of an attention window.
      attn_size: integer, the size of an attention vector. Equal to
          cell.output_size by default.
      attn_vec_size: integer, the number of convolutional features calculated
          on attention state and a size of the hidden layer built from
          base cell state. Equal attn_size to by default.
      input_size: integer, the size of a hidden linear layer,
          built from inputs and attention. Derived from the input tensor
          by default.
      state_is_tuple: If True, accepted and returned states are n-tuples, where
        `n = len(cells)`.  By default (False), the states are all
        concatenated along the column axis.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.

    Raises:
      TypeError: if cell is not an RNNCell.
      ValueError: if cell returns a state tuple but the flag
          `state_is_tuple` is `False` or if attn_length is zero or less.
    """
    super(AttentionCellWrapper, self).__init__(_reuse=reuse)
    rnn_cell_impl.assert_like_rnncell("cell", cell)
    if nest.is_sequence(cell.state_size) and not state_is_tuple:
      raise ValueError(
          "Cell returns tuple of states, but the flag "
          "state_is_tuple is not set. State size is: %s" % str(cell.state_size))
    if attn_length <= 0:
      raise ValueError(
          "attn_length should be greater than zero, got %s" % str(attn_length))
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated.  Use state_is_tuple=True.", self)
    if attn_size is None:
      attn_size = cell.output_size
    if attn_vec_size is None:
      attn_vec_size = attn_size
    self._state_is_tuple = state_is_tuple
    self._cell = cell
    self._attn_vec_size = attn_vec_size
    self._input_size = input_size
    self._attn_size = attn_size
    self._attn_length = attn_length
    self._reuse = reuse
    self._linear1 = None
    self._linear2 = None
    self._linear3 = None

  @property
  def state_size(self):
    size = (self._cell.state_size, self._attn_size,
            self._attn_size * self._attn_length)
    if self._state_is_tuple:
      return size
    else:
      return sum(list(size))

  @property
  def output_size(self):
    return self._attn_size

  def call(self, inputs, state):
    """Long short-term memory cell with attention (LSTMA)."""
    if self._state_is_tuple:
      state, attns, attn_states = state
    else:
      states = state
      state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size])
      attns = array_ops.slice(states, [0, self._cell.state_size],
                              [-1, self._attn_size])
      attn_states = array_ops.slice(
          states, [0, self._cell.state_size + self._attn_size],
          [-1, self._attn_size * self._attn_length])
    attn_states = array_ops.reshape(attn_states,
                                    [-1, self._attn_length, self._attn_size])
    input_size = self._input_size
    if input_size is None:
      input_size = inputs.get_shape().as_list()[1]
    if self._linear1 is None:
      self._linear1 = _Linear([inputs, attns], input_size, True)
    inputs = self._linear1([inputs, attns])
    cell_output, new_state = self._cell(inputs, state)
    if self._state_is_tuple:
      new_state_cat = array_ops.concat(nest.flatten(new_state), 1)
    else:
      new_state_cat = new_state
    new_attns, new_attn_states = self._attention(new_state_cat, attn_states)
    with vs.variable_scope("attn_output_projection"):
      if self._linear2 is None:
        self._linear2 = _Linear([cell_output, new_attns], self._attn_size, True)
      output = self._linear2([cell_output, new_attns])
    new_attn_states = array_ops.concat(
        [new_attn_states, array_ops.expand_dims(output, 1)], 1)
    new_attn_states = array_ops.reshape(
        new_attn_states, [-1, self._attn_length * self._attn_size])
    new_state = (new_state, new_attns, new_attn_states)
    if not self._state_is_tuple:
      new_state = array_ops.concat(list(new_state), 1)
    return output, new_state

  def _attention(self, query, attn_states):
    conv2d = nn_ops.conv2d
    reduce_sum = math_ops.reduce_sum
    softmax = nn_ops.softmax
    tanh = math_ops.tanh

    with vs.variable_scope("attention"):
      k = vs.get_variable("attn_w",
                          [1, 1, self._attn_size, self._attn_vec_size])
      v = vs.get_variable("attn_v", [self._attn_vec_size])
      hidden = array_ops.reshape(attn_states,
                                 [-1, self._attn_length, 1, self._attn_size])
      hidden_features = conv2d(hidden, k, [1, 1, 1, 1], "SAME")
      if self._linear3 is None:
        self._linear3 = _Linear(query, self._attn_vec_size, True)
      y = self._linear3(query)
      y = array_ops.reshape(y, [-1, 1, 1, self._attn_vec_size])
      s = reduce_sum(v * tanh(hidden_features + y), [2, 3])
      a = softmax(s)
      d = reduce_sum(
          array_ops.reshape(a, [-1, self._attn_length, 1, 1]) * hidden, [1, 2])
      new_attns = array_ops.reshape(d, [-1, self._attn_size])
      new_attn_states = array_ops.slice(attn_states, [0, 1, 0], [-1, -1, -1])
      return new_attns, new_attn_states


class HighwayWrapper(rnn_cell_impl.RNNCell):
  """RNNCell wrapper that adds highway connection on cell input and output.

  Based on:
    R. K. Srivastava, K. Greff, and J. Schmidhuber, "Highway networks",
    arXiv preprint arXiv:1505.00387, 2015.
    https://arxiv.org/abs/1505.00387
  """

  def __init__(self,
               cell,
               couple_carry_transform_gates=True,
               carry_bias_init=1.0):
    """Constructs a `HighwayWrapper` for `cell`.

    Args:
      cell: An instance of `RNNCell`.
      couple_carry_transform_gates: boolean, should the Carry and Transform gate
        be coupled.
      carry_bias_init: float, carry gates bias initialization.
    """
    self._cell = cell
    self._couple_carry_transform_gates = couple_carry_transform_gates
    self._carry_bias_init = carry_bias_init

  @property
  def state_size(self):
    return self._cell.state_size

  @property
  def output_size(self):
    return self._cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      return self._cell.zero_state(batch_size, dtype)

  def _highway(self, inp, out):
    input_size = inp.get_shape().with_rank(2)[1].value
    carry_weight = vs.get_variable("carry_w", [input_size, input_size])
    carry_bias = vs.get_variable(
        "carry_b", [input_size],
        initializer=init_ops.constant_initializer(self._carry_bias_init))
    carry = math_ops.sigmoid(nn_ops.xw_plus_b(inp, carry_weight, carry_bias))
    if self._couple_carry_transform_gates:
      transform = 1 - carry
    else:
      transform_weight = vs.get_variable("transform_w",
                                         [input_size, input_size])
      transform_bias = vs.get_variable(
          "transform_b", [input_size],
          initializer=init_ops.constant_initializer(-self._carry_bias_init))
      transform = math_ops.sigmoid(
          nn_ops.xw_plus_b(inp, transform_weight, transform_bias))
    return inp * carry + out * transform

  def __call__(self, inputs, state, scope=None):
    """Run the cell and add its inputs to its outputs.

    Args:
      inputs: cell inputs.
      state: cell state.
      scope: optional cell scope.

    Returns:
      Tuple of cell outputs and new state.

    Raises:
      TypeError: If cell inputs and outputs have different structure (type).
      ValueError: If cell inputs and outputs have different structure (value).
    """
    outputs, new_state = self._cell(inputs, state, scope=scope)
    nest.assert_same_structure(inputs, outputs)

    # Ensure shapes match
    def assert_shape_match(inp, out):
      inp.get_shape().assert_is_compatible_with(out.get_shape())

    nest.map_structure(assert_shape_match, inputs, outputs)
    res_outputs = nest.map_structure(self._highway, inputs, outputs)
    return (res_outputs, new_state)


class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
  """LSTM unit with layer normalization and recurrent dropout.

  This class adds layer normalization and recurrent dropout to a
  basic LSTM unit. Layer normalization implementation is based on:

    https://arxiv.org/abs/1607.06450.

  "Layer Normalization"
  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton

  and is applied before the internal nonlinearities.
  Recurrent dropout is base on:

    https://arxiv.org/abs/1603.05118

  "Recurrent Dropout without Memory Loss"
  Stanislau Semeniuta, Aliaksei Severyn, Erhardt Barth.
  """

  def __init__(self,
               num_units,
               forget_bias=1.0,
               input_size=None,
               activation=math_ops.tanh,
               layer_norm=True,
               norm_gain=1.0,
               norm_shift=0.0,
               dropout_keep_prob=1.0,
               dropout_prob_seed=None,
               reuse=None):
    """Initializes the basic LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
      input_size: Deprecated and unused.
      activation: Activation function of the inner states.
      layer_norm: If `True`, layer normalization will be applied.
      norm_gain: float, The layer normalization gain initial value. If
        `layer_norm` has been set to `False`, this argument will be ignored.
      norm_shift: float, The layer normalization shift initial value. If
        `layer_norm` has been set to `False`, this argument will be ignored.
      dropout_keep_prob: unit Tensor or float between 0 and 1 representing the
        recurrent dropout probability value. If float and 1.0, no dropout will
        be applied.
      dropout_prob_seed: (optional) integer, the randomness seed.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(LayerNormBasicLSTMCell, self).__init__(_reuse=reuse)

    if input_size is not None:
      logging.warn("%s: The input_size parameter is deprecated.", self)

    self._num_units = num_units
    self._activation = activation
    self._forget_bias = forget_bias
    self._keep_prob = dropout_keep_prob
    self._seed = dropout_prob_seed
    self._layer_norm = layer_norm
    self._norm_gain = norm_gain
    self._norm_shift = norm_shift
    self._reuse = reuse

  @property
  def state_size(self):
    return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)

  @property
  def output_size(self):
    return self._num_units

  def _norm(self, inp, scope, dtype=dtypes.float32):
    shape = inp.get_shape()[-1:]
    gamma_init = init_ops.constant_initializer(self._norm_gain)
    beta_init = init_ops.constant_initializer(self._norm_shift)
    with vs.variable_scope(scope):
      # Initialize beta and gamma for use by layer_norm.
      vs.get_variable("gamma", shape=shape, initializer=gamma_init, dtype=dtype)
      vs.get_variable("beta", shape=shape, initializer=beta_init, dtype=dtype)
    normalized = layers.layer_norm(inp, reuse=True, scope=scope)
    return normalized

  def _linear(self, args):
    out_size = 4 * self._num_units
    proj_size = args.get_shape()[-1]
    dtype = args.dtype
    weights = vs.get_variable("kernel", [proj_size, out_size], dtype=dtype)
    out = math_ops.matmul(args, weights)
    if not self._layer_norm:
      bias = vs.get_variable("bias", [out_size], dtype=dtype)
      out = nn_ops.bias_add(out, bias)
    return out

  def call(self, inputs, state):
    """LSTM cell with layer normalization and recurrent dropout."""
    c, h = state
    args = array_ops.concat([inputs, h], 1)
    concat = self._linear(args)
    dtype = args.dtype

    i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
    if self._layer_norm:
      i = self._norm(i, "input", dtype=dtype)
      j = self._norm(j, "transform", dtype=dtype)
      f = self._norm(f, "forget", dtype=dtype)
      o = self._norm(o, "output", dtype=dtype)

    g = self._activation(j)
    if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
      g = nn_ops.dropout(g, self._keep_prob, seed=self._seed)

    new_c = (
        c * math_ops.sigmoid(f + self._forget_bias) + math_ops.sigmoid(i) * g)
    if self._layer_norm:
      new_c = self._norm(new_c, "state", dtype=dtype)
    new_h = self._activation(new_c) * math_ops.sigmoid(o)

    new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
    return new_h, new_state


class NASCell(rnn_cell_impl.RNNCell):
  """Neural Architecture Search (NAS) recurrent network cell.

  This implements the recurrent cell from the paper:

    https://arxiv.org/abs/1611.01578

  Barret Zoph and Quoc V. Le.
  "Neural Architecture Search with Reinforcement Learning" Proc. ICLR 2017.

  The class uses an optional projection layer.
  """

  def __init__(self, num_units, num_proj=None, use_biases=False, reuse=None):
    """Initialize the parameters for a NAS cell.

    Args:
      num_units: int, The number of units in the NAS cell
      num_proj: (optional) int, The output dimensionality for the projection
        matrices.  If None, no projection is performed.
      use_biases: (optional) bool, If True then use biases within the cell. This
        is False by default.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(NASCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._num_proj = num_proj
    self._use_biases = use_biases
    self._reuse = reuse

    if num_proj is not None:
      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
      self._output_size = num_proj
    else:
      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
      self._output_size = num_units

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def call(self, inputs, state):
    """Run one step of NAS Cell.

    Args:
      inputs: input Tensor, 2D, batch x num_units.
      state: This must be a tuple of state Tensors, both `2-D`, with column
        sizes `c_state` and `m_state`.

    Returns:
      A tuple containing:
      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        NAS Cell after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of NAS Cell after reading `inputs`
        when the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    sigmoid = math_ops.sigmoid
    tanh = math_ops.tanh
    relu = nn_ops.relu

    num_proj = self._num_units if self._num_proj is None else self._num_proj

    (c_prev, m_prev) = state

    dtype = inputs.dtype
    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
    # Variables for the NAS cell. W_m is all matrices multiplying the
    # hiddenstate and W_inputs is all matrices multiplying the inputs.
    concat_w_m = vs.get_variable("recurrent_kernel",
                                 [num_proj, 8 * self._num_units], dtype)
    concat_w_inputs = vs.get_variable(
        "kernel", [input_size.value, 8 * self._num_units], dtype)

    m_matrix = math_ops.matmul(m_prev, concat_w_m)
    inputs_matrix = math_ops.matmul(inputs, concat_w_inputs)

    if self._use_biases:
      b = vs.get_variable(
          "bias",
          shape=[8 * self._num_units],
          initializer=init_ops.zeros_initializer(),
          dtype=dtype)
      m_matrix = nn_ops.bias_add(m_matrix, b)

    # The NAS cell branches into 8 different splits for both the hiddenstate
    # and the input
    m_matrix_splits = array_ops.split(
        axis=1, num_or_size_splits=8, value=m_matrix)
    inputs_matrix_splits = array_ops.split(
        axis=1, num_or_size_splits=8, value=inputs_matrix)

    # First layer
    layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0])
    layer1_1 = relu(inputs_matrix_splits[1] + m_matrix_splits[1])
    layer1_2 = sigmoid(inputs_matrix_splits[2] + m_matrix_splits[2])
    layer1_3 = relu(inputs_matrix_splits[3] * m_matrix_splits[3])
    layer1_4 = tanh(inputs_matrix_splits[4] + m_matrix_splits[4])
    layer1_5 = sigmoid(inputs_matrix_splits[5] + m_matrix_splits[5])
    layer1_6 = tanh(inputs_matrix_splits[6] + m_matrix_splits[6])
    layer1_7 = sigmoid(inputs_matrix_splits[7] + m_matrix_splits[7])

    # Second layer
    l2_0 = tanh(layer1_0 * layer1_1)
    l2_1 = tanh(layer1_2 + layer1_3)
    l2_2 = tanh(layer1_4 * layer1_5)
    l2_3 = sigmoid(layer1_6 + layer1_7)

    # Inject the cell
    l2_0 = tanh(l2_0 + c_prev)

    # Third layer
    l3_0_pre = l2_0 * l2_1
    new_c = l3_0_pre  # create new cell
    l3_0 = l3_0_pre
    l3_1 = tanh(l2_2 + l2_3)

    # Final layer
    new_m = tanh(l3_0 * l3_1)

    # Projection layer if specified
    if self._num_proj is not None:
      concat_w_proj = vs.get_variable("projection_weights",
                                      [self._num_units, self._num_proj], dtype)
      new_m = math_ops.matmul(new_m, concat_w_proj)

    new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_m)
    return new_m, new_state


class UGRNNCell(rnn_cell_impl.RNNCell):
  """Update Gate Recurrent Neural Network (UGRNN) cell.

  Compromise between a LSTM/GRU and a vanilla RNN.  There is only one
  gate, and that is to determine whether the unit should be
  integrating or computing instantaneously.  This is the recurrent
  idea of the feedforward Highway Network.

  This implements the recurrent cell from the paper:

    https://arxiv.org/abs/1611.09913

  Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo.
  "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017.
  """

  def __init__(self,
               num_units,
               initializer=None,
               forget_bias=1.0,
               activation=math_ops.tanh,
               reuse=None):
    """Initialize the parameters for an UGRNN cell.

    Args:
      num_units: int, The number of units in the UGRNN cell
      initializer: (optional) The initializer to use for the weight matrices.
      forget_bias: (optional) float, default 1.0, The initial bias of the
        forget gate, used to reduce the scale of forgetting at the beginning
        of the training.
      activation: (optional) Activation function of the inner states.
        Default is `tf.tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(UGRNNCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._initializer = initializer
    self._forget_bias = forget_bias
    self._activation = activation
    self._reuse = reuse
    self._linear = None

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def call(self, inputs, state):
    """Run one step of UGRNN.

    Args:
      inputs: input Tensor, 2D, batch x input size.
      state: state Tensor, 2D, batch x num units.

    Returns:
      new_output: batch x num units, Tensor representing the output of the UGRNN
        after reading `inputs` when previous state was `state`. Identical to
        `new_state`.
      new_state: batch x num units, Tensor representing the state of the UGRNN
        after reading `inputs` when previous state was `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    sigmoid = math_ops.sigmoid

    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    with vs.variable_scope(
        vs.get_variable_scope(), initializer=self._initializer):
      cell_inputs = array_ops.concat([inputs, state], 1)
      if self._linear is None:
        self._linear = _Linear(cell_inputs, 2 * self._num_units, True)
      rnn_matrix = self._linear(cell_inputs)

      [g_act, c_act] = array_ops.split(
          axis=1, num_or_size_splits=2, value=rnn_matrix)

      c = self._activation(c_act)
      g = sigmoid(g_act + self._forget_bias)
      new_state = g * state + (1.0 - g) * c
      new_output = new_state

    return new_output, new_state


class IntersectionRNNCell(rnn_cell_impl.RNNCell):
  """Intersection Recurrent Neural Network (+RNN) cell.

  Architecture with coupled recurrent gate as well as coupled depth
  gate, designed to improve information flow through stacked RNNs. As the
  architecture uses depth gating, the dimensionality of the depth
  output (y) also should not change through depth (input size == output size).
  To achieve this, the first layer of a stacked Intersection RNN projects
  the inputs to N (num units) dimensions. Therefore when initializing an
  IntersectionRNNCell, one should set `num_in_proj = N` for the first layer
  and use default settings for subsequent layers.

  This implements the recurrent cell from the paper:

    https://arxiv.org/abs/1611.09913

  Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo.
  "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017.

  The Intersection RNN is built for use in deeply stacked
  RNNs so it may not achieve best performance with depth 1.
  """

  def __init__(self,
               num_units,
               num_in_proj=None,
               initializer=None,
               forget_bias=1.0,
               y_activation=nn_ops.relu,
               reuse=None):
    """Initialize the parameters for an +RNN cell.

    Args:
      num_units: int, The number of units in the +RNN cell
      num_in_proj: (optional) int, The input dimensionality for the RNN.
        If creating the first layer of an +RNN, this should be set to
        `num_units`. Otherwise, this should be set to `None` (default).
        If `None`, dimensionality of `inputs` should be equal to `num_units`,
        otherwise ValueError is thrown.
      initializer: (optional) The initializer to use for the weight matrices.
      forget_bias: (optional) float, default 1.0, The initial bias of the
        forget gates, used to reduce the scale of forgetting at the beginning
        of the training.
      y_activation: (optional) Activation function of the states passed
        through depth. Default is 'tf.nn.relu`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(IntersectionRNNCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._initializer = initializer
    self._forget_bias = forget_bias
    self._num_input_proj = num_in_proj
    self._y_activation = y_activation
    self._reuse = reuse
    self._linear1 = None
    self._linear2 = None

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def call(self, inputs, state):
    """Run one step of the Intersection RNN.

    Args:
      inputs: input Tensor, 2D, batch x input size.
      state: state Tensor, 2D, batch x num units.

    Returns:
      new_y: batch x num units, Tensor representing the output of the +RNN
        after reading `inputs` when previous state was `state`.
      new_state: batch x num units, Tensor representing the state of the +RNN
        after reading `inputs` when previous state was `state`.

    Raises:
      ValueError: If input size cannot be inferred from `inputs` via
        static shape inference.
      ValueError: If input size != output size (these must be equal when
        using the Intersection RNN).
    """
    sigmoid = math_ops.sigmoid
    tanh = math_ops.tanh

    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    with vs.variable_scope(
        vs.get_variable_scope(), initializer=self._initializer):
      # read-in projections (should be used for first layer in deep +RNN
      # to transform size of inputs from I --> N)
      if input_size.value != self._num_units:
        if self._num_input_proj:
          with vs.variable_scope("in_projection"):
            if self._linear1 is None:
              self._linear1 = _Linear(inputs, self._num_units, True)
            inputs = self._linear1(inputs)
        else:
          raise ValueError("Must have input size == output size for "
                           "Intersection RNN. To fix, num_in_proj should "
                           "be set to num_units at cell init.")

      n_dim = i_dim = self._num_units
      cell_inputs = array_ops.concat([inputs, state], 1)
      if self._linear2 is None:
        self._linear2 = _Linear(cell_inputs, 2 * n_dim + 2 * i_dim, True)
      rnn_matrix = self._linear2(cell_inputs)

      gh_act = rnn_matrix[:, :n_dim]  # b x n
      h_act = rnn_matrix[:, n_dim:2 * n_dim]  # b x n
      gy_act = rnn_matrix[:, 2 * n_dim:2 * n_dim + i_dim]  # b x i
      y_act = rnn_matrix[:, 2 * n_dim + i_dim:2 * n_dim + 2 * i_dim]  # b x i

      h = tanh(h_act)
      y = self._y_activation(y_act)
      gh = sigmoid(gh_act + self._forget_bias)
      gy = sigmoid(gy_act + self._forget_bias)

      new_state = gh * state + (1.0 - gh) * h  # passed thru time
      new_y = gy * inputs + (1.0 - gy) * y  # passed thru depth

    return new_y, new_state


_REGISTERED_OPS = None


class CompiledWrapper(rnn_cell_impl.RNNCell):
  """Wraps step execution in an XLA JIT scope."""

  def __init__(self, cell, compile_stateful=False):
    """Create CompiledWrapper cell.

    Args:
      cell: Instance of `RNNCell`.
      compile_stateful: Whether to compile stateful ops like initializers
        and random number generators (default: False).
    """
    self._cell = cell
    self._compile_stateful = compile_stateful

  @property
  def state_size(self):
    return self._cell.state_size

  @property
  def output_size(self):
    return self._cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      return self._cell.zero_state(batch_size, dtype)

  def __call__(self, inputs, state, scope=None):
    if self._compile_stateful:
      compile_ops = True
    else:

      def compile_ops(node_def):
        global _REGISTERED_OPS
        if _REGISTERED_OPS is None:
          _REGISTERED_OPS = op_def_registry.get_registered_ops()
        return not _REGISTERED_OPS[node_def.op].is_stateful

    with jit.experimental_jit_scope(compile_ops=compile_ops):
      return self._cell(inputs, state, scope=scope)


def _random_exp_initializer(minval, maxval, seed=None, dtype=dtypes.float32):
  """Returns an exponential distribution initializer.

  Args:
    minval: float or a scalar float Tensor. With value > 0. Lower bound of the
        range of random values to generate.
    maxval: float or a scalar float Tensor. With value > minval. Upper bound of
        the range of random values to generate.
    seed: An integer. Used to create random seeds.
    dtype: The data type.

  Returns:
    An initializer that generates tensors with an exponential distribution.
  """

  def _initializer(shape, dtype=dtype, partition_info=None):
    del partition_info  # Unused.
    return math_ops.exp(
        random_ops.random_uniform(
            shape, math_ops.log(minval), math_ops.log(maxval), dtype,
            seed=seed))

  return _initializer


class PhasedLSTMCell(rnn_cell_impl.RNNCell):
  """Phased LSTM recurrent network cell.

  https://arxiv.org/pdf/1610.09513v1.pdf
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               leak=0.001,
               ratio_on=0.1,
               trainable_ratio_on=True,
               period_init_min=1.0,
               period_init_max=1000.0,
               reuse=None):
    """Initialize the Phased LSTM cell.

    Args:
      num_units: int, The number of units in the Phased LSTM cell.
      use_peepholes: bool, set True to enable peephole connections.
      leak: float or scalar float Tensor with value in [0, 1]. Leak applied
          during training.
      ratio_on: float or scalar float Tensor with value in [0, 1]. Ratio of the
          period during which the gates are open.
      trainable_ratio_on: bool, weather ratio_on is trainable.
      period_init_min: float or scalar float Tensor. With value > 0.
          Minimum value of the initialized period.
          The period values are initialized by drawing from the distribution:
          e^U(log(period_init_min), log(period_init_max))
          Where U(.,.) is the uniform distribution.
      period_init_max: float or scalar float Tensor.
          With value > period_init_min. Maximum value of the initialized period.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(PhasedLSTMCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._leak = leak
    self._ratio_on = ratio_on
    self._trainable_ratio_on = trainable_ratio_on
    self._period_init_min = period_init_min
    self._period_init_max = period_init_max
    self._reuse = reuse
    self._linear1 = None
    self._linear2 = None
    self._linear3 = None

  @property
  def state_size(self):
    return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)

  @property
  def output_size(self):
    return self._num_units

  def _mod(self, x, y):
    """Modulo function that propagates x gradients."""
    return array_ops.stop_gradient(math_ops.mod(x, y) - x) + x

  def _get_cycle_ratio(self, time, phase, period):
    """Compute the cycle ratio in the dtype of the time."""
    phase_casted = math_ops.cast(phase, dtype=time.dtype)
    period_casted = math_ops.cast(period, dtype=time.dtype)
    shifted_time = time - phase_casted
    cycle_ratio = self._mod(shifted_time, period_casted) / period_casted
    return math_ops.cast(cycle_ratio, dtype=dtypes.float32)

  def call(self, inputs, state):
    """Phased LSTM Cell.

    Args:
      inputs: A tuple of 2 Tensor.
         The first Tensor has shape [batch, 1], and type float32 or float64.
         It stores the time.
         The second Tensor has shape [batch, features_size], and type float32.
         It stores the features.
      state: rnn_cell_impl.LSTMStateTuple, state from previous timestep.

    Returns:
      A tuple containing:
      - A Tensor of float32, and shape [batch_size, num_units], representing the
        output of the cell.
      - A rnn_cell_impl.LSTMStateTuple, containing 2 Tensors of float32, shape
        [batch_size, num_units], representing the new state and the output.
    """
    (c_prev, h_prev) = state
    (time, x) = inputs

    in_mask_gates = [x, h_prev]
    if self._use_peepholes:
      in_mask_gates.append(c_prev)

    with vs.variable_scope("mask_gates"):
      if self._linear1 is None:
        self._linear1 = _Linear(in_mask_gates, 2 * self._num_units, True)

      mask_gates = math_ops.sigmoid(self._linear1(in_mask_gates))
      [input_gate, forget_gate] = array_ops.split(
          axis=1, num_or_size_splits=2, value=mask_gates)

    with vs.variable_scope("new_input"):
      if self._linear2 is None:
        self._linear2 = _Linear([x, h_prev], self._num_units, True)
      new_input = math_ops.tanh(self._linear2([x, h_prev]))

    new_c = (c_prev * forget_gate + input_gate * new_input)

    in_out_gate = [x, h_prev]
    if self._use_peepholes:
      in_out_gate.append(new_c)

    with vs.variable_scope("output_gate"):
      if self._linear3 is None:
        self._linear3 = _Linear(in_out_gate, self._num_units, True)
      output_gate = math_ops.sigmoid(self._linear3(in_out_gate))

    new_h = math_ops.tanh(new_c) * output_gate

    period = vs.get_variable(
        "period", [self._num_units],
        initializer=_random_exp_initializer(self._period_init_min,
                                            self._period_init_max))
    phase = vs.get_variable(
        "phase", [self._num_units],
        initializer=init_ops.random_uniform_initializer(0.,
                                                        period.initial_value))
    ratio_on = vs.get_variable(
        "ratio_on", [self._num_units],
        initializer=init_ops.constant_initializer(self._ratio_on),
        trainable=self._trainable_ratio_on)

    cycle_ratio = self._get_cycle_ratio(time, phase, period)

    k_up = 2 * cycle_ratio / ratio_on
    k_down = 2 - k_up
    k_closed = self._leak * cycle_ratio

    k = array_ops.where(cycle_ratio < ratio_on, k_down, k_closed)
    k = array_ops.where(cycle_ratio < 0.5 * ratio_on, k_up, k)

    new_c = k * new_c + (1 - k) * c_prev
    new_h = k * new_h + (1 - k) * h_prev

    new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)

    return new_h, new_state


class ConvLSTMCell(rnn_cell_impl.RNNCell):
  """Convolutional LSTM recurrent network cell.

  https://arxiv.org/pdf/1506.04214v1.pdf
  """

  def __init__(self,
               conv_ndims,
               input_shape,
               output_channels,
               kernel_shape,
               use_bias=True,
               skip_connection=False,
               forget_bias=1.0,
               initializers=None,
               name="conv_lstm_cell"):
    """Construct ConvLSTMCell.

    Args:
      conv_ndims: Convolution dimensionality (1, 2 or 3).
      input_shape: Shape of the input as int tuple, excluding the batch size.
      output_channels: int, number of output channels of the conv LSTM.
      kernel_shape: Shape of kernel as in tuple (of size 1,2 or 3).
      use_bias: (bool) Use bias in convolutions.
      skip_connection: If set to `True`, concatenate the input to the
        output of the conv LSTM. Default: `False`.
      forget_bias: Forget bias.
      initializers: Unused.
      name: Name of the module.

    Raises:
      ValueError: If `skip_connection` is `True` and stride is different from 1
        or if `input_shape` is incompatible with `conv_ndims`.
    """
    super(ConvLSTMCell, self).__init__(name=name)

    if conv_ndims != len(input_shape) - 1:
      raise ValueError("Invalid input_shape {} for conv_ndims={}.".format(
          input_shape, conv_ndims))

    self._conv_ndims = conv_ndims
    self._input_shape = input_shape
    self._output_channels = output_channels
    self._kernel_shape = kernel_shape
    self._use_bias = use_bias
    self._forget_bias = forget_bias
    self._skip_connection = skip_connection

    self._total_output_channels = output_channels
    if self._skip_connection:
      self._total_output_channels += self._input_shape[-1]

    state_size = tensor_shape.TensorShape(
        self._input_shape[:-1] + [self._output_channels])
    self._state_size = rnn_cell_impl.LSTMStateTuple(state_size, state_size)
    self._output_size = tensor_shape.TensorShape(
        self._input_shape[:-1] + [self._total_output_channels])

  @property
  def output_size(self):
    return self._output_size

  @property
  def state_size(self):
    return self._state_size

  def call(self, inputs, state, scope=None):
    cell, hidden = state
    new_hidden = _conv([inputs, hidden], self._kernel_shape,
                       4 * self._output_channels, self._use_bias)
    gates = array_ops.split(
        value=new_hidden, num_or_size_splits=4, axis=self._conv_ndims + 1)

    input_gate, new_input, forget_gate, output_gate = gates
    new_cell = math_ops.sigmoid(forget_gate + self._forget_bias) * cell
    new_cell += math_ops.sigmoid(input_gate) * math_ops.tanh(new_input)
    output = math_ops.tanh(new_cell) * math_ops.sigmoid(output_gate)

    if self._skip_connection:
      output = array_ops.concat([output, inputs], axis=-1)
    new_state = rnn_cell_impl.LSTMStateTuple(new_cell, output)
    return output, new_state


class Conv1DLSTMCell(ConvLSTMCell):
  """1D Convolutional LSTM recurrent network cell.

  https://arxiv.org/pdf/1506.04214v1.pdf
  """

  def __init__(self, name="conv_1d_lstm_cell", **kwargs):
    """Construct Conv1DLSTM. See `ConvLSTMCell` for more details."""
    super(Conv1DLSTMCell, self).__init__(conv_ndims=1, name=name, **kwargs)


class Conv2DLSTMCell(ConvLSTMCell):
  """2D Convolutional LSTM recurrent network cell.

  https://arxiv.org/pdf/1506.04214v1.pdf
  """

  def __init__(self, name="conv_2d_lstm_cell", **kwargs):
    """Construct Conv2DLSTM. See `ConvLSTMCell` for more details."""
    super(Conv2DLSTMCell, self).__init__(conv_ndims=2, name=name, **kwargs)


class Conv3DLSTMCell(ConvLSTMCell):
  """3D Convolutional LSTM recurrent network cell.

  https://arxiv.org/pdf/1506.04214v1.pdf
  """

  def __init__(self, name="conv_3d_lstm_cell", **kwargs):
    """Construct Conv3DLSTM. See `ConvLSTMCell` for more details."""
    super(Conv3DLSTMCell, self).__init__(conv_ndims=3, name=name, **kwargs)


def _conv(args, filter_size, num_features, bias, bias_start=0.0):
  """Convolution.

  Args:
    args: a Tensor or a list of Tensors of dimension 3D, 4D or 5D,
    batch x n, Tensors.
    filter_size: int tuple of filter height and width.
    num_features: int, number of features.
    bias: Whether to use biases in the convolution layer.
    bias_start: starting value to initialize the bias; 0 by default.

  Returns:
    A 3D, 4D, or 5D Tensor with shape [batch ... num_features]

  Raises:
    ValueError: if some of the arguments has unspecified or wrong shape.
  """

  # Calculate the total size of arguments on dimension 1.
  total_arg_size_depth = 0
  shapes = [a.get_shape().as_list() for a in args]
  shape_length = len(shapes[0])
  for shape in shapes:
    if len(shape) not in [3, 4, 5]:
      raise ValueError("Conv Linear expects 3D, 4D "
                       "or 5D arguments: %s" % str(shapes))
    if len(shape) != len(shapes[0]):
      raise ValueError("Conv Linear expects all args "
                       "to be of same Dimension: %s" % str(shapes))
    else:
      total_arg_size_depth += shape[-1]
  dtype = [a.dtype for a in args][0]

  # determine correct conv operation
  if shape_length == 3:
    conv_op = nn_ops.conv1d
    strides = 1
  elif shape_length == 4:
    conv_op = nn_ops.conv2d
    strides = shape_length * [1]
  elif shape_length == 5:
    conv_op = nn_ops.conv3d
    strides = shape_length * [1]

  # Now the computation.
  kernel = vs.get_variable(
      "kernel", filter_size + [total_arg_size_depth, num_features], dtype=dtype)
  if len(args) == 1:
    res = conv_op(args[0], kernel, strides, padding="SAME")
  else:
    res = conv_op(
        array_ops.concat(axis=shape_length - 1, values=args),
        kernel,
        strides,
        padding="SAME")
  if not bias:
    return res
  bias_term = vs.get_variable(
      "biases", [num_features],
      dtype=dtype,
      initializer=init_ops.constant_initializer(bias_start, dtype=dtype))
  return res + bias_term


class GLSTMCell(rnn_cell_impl.RNNCell):
  """Group LSTM cell (G-LSTM).

  The implementation is based on:

    https://arxiv.org/abs/1703.10722

  O. Kuchaiev and B. Ginsburg
  "Factorization Tricks for LSTM Networks", ICLR 2017 workshop.

  In brief, a G-LSTM cell consists of one LSTM sub-cell per group, where each
  sub-cell operates on an evenly-sized sub-vector of the input and produces an
  evenly-sized sub-vector of the output.  For example, a G-LSTM cell with 128
  units and 4 groups consists of 4 LSTMs sub-cells with 32 units each.  If that
  G-LSTM cell is fed a 200-dim input, then each sub-cell receives a 50-dim part
  of the input and produces a 32-dim part of the output.
  """

  def __init__(self,
               num_units,
               initializer=None,
               num_proj=None,
               number_of_groups=1,
               forget_bias=1.0,
               activation=math_ops.tanh,
               reuse=None):
    """Initialize the parameters of G-LSTM cell.

    Args:
      num_units: int, The number of units in the G-LSTM cell
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices.  If None, no projection is performed.
      number_of_groups: (optional) int, number of groups to use.
        If `number_of_groups` is 1, then it should be equivalent to LSTM cell
      forget_bias: Biases of the forget gate are initialized by default to 1
        in order to reduce the scale of forgetting at the beginning of
        the training.
      activation: Activation function of the inner states.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already
        has the given variables, an error is raised.

    Raises:
      ValueError: If `num_units` or `num_proj` is not divisible by
        `number_of_groups`.
    """
    super(GLSTMCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._initializer = initializer
    self._num_proj = num_proj
    self._forget_bias = forget_bias
    self._activation = activation
    self._number_of_groups = number_of_groups

    if self._num_units % self._number_of_groups != 0:
      raise ValueError("num_units must be divisible by number_of_groups")
    if self._num_proj:
      if self._num_proj % self._number_of_groups != 0:
        raise ValueError("num_proj must be divisible by number_of_groups")
      self._group_shape = [
          int(self._num_proj / self._number_of_groups),
          int(self._num_units / self._number_of_groups)
      ]
    else:
      self._group_shape = [
          int(self._num_units / self._number_of_groups),
          int(self._num_units / self._number_of_groups)
      ]

    if num_proj:
      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
      self._output_size = num_proj
    else:
      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
      self._output_size = num_units
    self._linear1 = [None] * number_of_groups
    self._linear2 = None

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def _get_input_for_group(self, inputs, group_id, group_size):
    """Slices inputs into groups to prepare for processing by cell's groups.

    Args:
      inputs: cell input or it's previous state,
              a Tensor, 2D, [batch x num_units]
      group_id: group id, a Scalar, for which to prepare input
      group_size: size of the group

    Returns:
      subset of inputs corresponding to group "group_id",
      a Tensor, 2D, [batch x num_units/number_of_groups]
    """
    return array_ops.slice(
        input_=inputs,
        begin=[0, group_id * group_size],
        size=[self._batch_size, group_size],
        name=("GLSTM_group%d_input_generation" % group_id))

  def call(self, inputs, state):
    """Run one step of G-LSTM.

    Args:
      inputs: input Tensor, 2D, [batch x num_inputs].  num_inputs must be
        statically-known and evenly divisible into groups.  The innermost
        vectors of the inputs are split into evenly-sized sub-vectors and fed
        into the per-group LSTM sub-cells.
      state: this must be a tuple of state Tensors, both `2-D`, with column
        sizes `c_state` and `m_state`.

    Returns:
      A tuple containing:

      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        G-LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - LSTMStateTuple representing the new state of G-LSTM cell
        after reading `inputs` when the previous state was `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference, or if the input shape is incompatible
        with the number of groups.
    """
    (c_prev, m_prev) = state

    self._batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]

    # If the input size is statically-known, calculate and validate its group
    # size.  Otherwise, use the output group size.
    input_size = inputs.shape[1].value
    if input_size is None:
      raise ValueError("input size must be statically known")
    if input_size % self._number_of_groups != 0:
      raise ValueError(
          "input size (%d) must be divisible by number_of_groups (%d)" %
          (input_size, self._number_of_groups))
    input_group_size = int(input_size / self._number_of_groups)

    dtype = inputs.dtype
    scope = vs.get_variable_scope()
    with vs.variable_scope(scope, initializer=self._initializer):
      i_parts = []
      j_parts = []
      f_parts = []
      o_parts = []

      for group_id in range(self._number_of_groups):
        with vs.variable_scope("group%d" % group_id):
          x_g_id = array_ops.concat(
              [
                  self._get_input_for_group(inputs, group_id, input_group_size),
                  self._get_input_for_group(m_prev, group_id,
                                            self._group_shape[0])
              ],
              axis=1)
          linear = self._linear1[group_id]
          if linear is None:
            linear = _Linear(x_g_id, 4 * self._group_shape[1], False)
            self._linear1[group_id] = linear
          R_k = linear(x_g_id)  # pylint: disable=invalid-name
          i_k, j_k, f_k, o_k = array_ops.split(R_k, 4, 1)

        i_parts.append(i_k)
        j_parts.append(j_k)
        f_parts.append(f_k)
        o_parts.append(o_k)

      bi = vs.get_variable(
          name="bias_i",
          shape=[self._num_units],
          dtype=dtype,
          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
      bj = vs.get_variable(
          name="bias_j",
          shape=[self._num_units],
          dtype=dtype,
          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
      bf = vs.get_variable(
          name="bias_f",
          shape=[self._num_units],
          dtype=dtype,
          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
      bo = vs.get_variable(
          name="bias_o",
          shape=[self._num_units],
          dtype=dtype,
          initializer=init_ops.constant_initializer(0.0, dtype=dtype))

      i = nn_ops.bias_add(array_ops.concat(i_parts, axis=1), bi)
      j = nn_ops.bias_add(array_ops.concat(j_parts, axis=1), bj)
      f = nn_ops.bias_add(array_ops.concat(f_parts, axis=1), bf)
      o = nn_ops.bias_add(array_ops.concat(o_parts, axis=1), bo)

    c = (
        math_ops.sigmoid(f + self._forget_bias) * c_prev +
        math_ops.sigmoid(i) * math_ops.tanh(j))
    m = math_ops.sigmoid(o) * self._activation(c)

    if self._num_proj is not None:
      with vs.variable_scope("projection"):
        if self._linear2 is None:
          self._linear2 = _Linear(m, self._num_proj, False)
        m = self._linear2(m)

    new_state = rnn_cell_impl.LSTMStateTuple(c, m)
    return m, new_state


class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
  """Long short-term memory unit (LSTM) recurrent network cell.

  The default non-peephole implementation is based on:

    http://www.bioinf.jku.at/publications/older/2604.pdf

  S. Hochreiter and J. Schmidhuber.
  "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.

  The peephole implementation is based on:

    https://research.google.com/pubs/archive/43905.pdf

  Hasim Sak, Andrew Senior, and Francoise Beaufays.
  "Long short-term memory recurrent neural network architectures for
   large scale acoustic modeling." INTERSPEECH, 2014.

  The class uses optional peep-hole connections, optional cell clipping, and
  an optional projection layer.

  Layer normalization implementation is based on:

    https://arxiv.org/abs/1607.06450.

  "Layer Normalization"
  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton

  and is applied before the internal nonlinearities.

  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               cell_clip=None,
               initializer=None,
               num_proj=None,
               proj_clip=None,
               forget_bias=1.0,
               activation=None,
               layer_norm=False,
               norm_gain=1.0,
               norm_shift=0.0,
               reuse=None):
    """Initialize the parameters for an LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell
      use_peepholes: bool, set True to enable diagonal/peephole connections.
      cell_clip: (optional) A float value, if provided the cell state is clipped
        by this value prior to the cell output activation.
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices.  If None, no projection is performed.
      proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
        provided, then the projected values are clipped elementwise to within
        `[-proj_clip, proj_clip]`.
      forget_bias: Biases of the forget gate are initialized by default to 1
        in order to reduce the scale of forgetting at the beginning of
        the training. Must set it manually to `0.0` when restoring from
        CudnnLSTM trained checkpoints.
      activation: Activation function of the inner states.  Default: `tanh`.
      layer_norm: If `True`, layer normalization will be applied.
      norm_gain: float, The layer normalization gain initial value. If
        `layer_norm` has been set to `False`, this argument will be ignored.
      norm_shift: float, The layer normalization shift initial value. If
        `layer_norm` has been set to `False`, this argument will be ignored.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.

      When restoring from CudnnLSTM-trained checkpoints, must use
      CudnnCompatibleLSTMCell instead.
    """
    super(LayerNormLSTMCell, self).__init__(_reuse=reuse)

    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._cell_clip = cell_clip
    self._initializer = initializer
    self._num_proj = num_proj
    self._proj_clip = proj_clip
    self._forget_bias = forget_bias
    self._activation = activation or math_ops.tanh
    self._layer_norm = layer_norm
    self._norm_gain = norm_gain
    self._norm_shift = norm_shift

    if num_proj:
      self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj))
      self._output_size = num_proj
    else:
      self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_units))
      self._output_size = num_units

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def _linear(self,
              args,
              output_size,
              bias,
              bias_initializer=None,
              kernel_initializer=None,
              layer_norm=False):
    """Linear map: sum_i(args[i] * W[i]), where W[i] is a Variable.

    Args:
      args: a 2D Tensor or a list of 2D, batch x n, Tensors.
      output_size: int, second dimension of W[i].
      bias: boolean, whether to add a bias term or not.
      bias_initializer: starting value to initialize the bias
        (default is all zeros).
      kernel_initializer: starting value to initialize the weight.
      layer_norm: boolean, whether to apply layer normalization.


    Returns:
      A 2D Tensor with shape [batch x output_size] taking value
      sum_i(args[i] * W[i]), where each W[i] is a newly created Variable.

    Raises:
      ValueError: if some of the arguments has unspecified or wrong shape.
    """
    if args is None or (nest.is_sequence(args) and not args):
      raise ValueError("`args` must be specified")
    if not nest.is_sequence(args):
      args = [args]

    # Calculate the total size of arguments on dimension 1.
    total_arg_size = 0
    shapes = [a.get_shape() for a in args]
    for shape in shapes:
      if shape.ndims != 2:
        raise ValueError("linear is expecting 2D arguments: %s" % shapes)
      if shape[1].value is None:
        raise ValueError("linear expects shape[1] to be provided for shape %s, "
                         "but saw %s" % (shape, shape[1]))
      else:
        total_arg_size += shape[1].value

    dtype = [a.dtype for a in args][0]

    # Now the computation.
    scope = vs.get_variable_scope()
    with vs.variable_scope(scope) as outer_scope:
      weights = vs.get_variable(
          "kernel", [total_arg_size, output_size],
          dtype=dtype,
          initializer=kernel_initializer)
      if len(args) == 1:
        res = math_ops.matmul(args[0], weights)
      else:
        res = math_ops.matmul(array_ops.concat(args, 1), weights)
      if not bias:
        return res
      with vs.variable_scope(outer_scope) as inner_scope:
        inner_scope.set_partitioner(None)
        if bias_initializer is None:
          bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
        biases = vs.get_variable(
            "bias", [output_size], dtype=dtype, initializer=bias_initializer)

    if not layer_norm:
      res = nn_ops.bias_add(res, biases)

    return res

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, batch x num_units.
      state: this must be a tuple of state Tensors,
       both `2-D`, with column sizes `c_state` and
        `m_state`.

    Returns:
      A tuple containing:

      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of LSTM after reading `inputs` when
        the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    sigmoid = math_ops.sigmoid

    (c_prev, m_prev) = state

    dtype = inputs.dtype
    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
    scope = vs.get_variable_scope()
    with vs.variable_scope(scope, initializer=self._initializer) as unit_scope:

      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      lstm_matrix = self._linear(
          [inputs, m_prev],
          4 * self._num_units,
          bias=True,
          bias_initializer=None,
          layer_norm=self._layer_norm)
      i, j, f, o = array_ops.split(
          value=lstm_matrix, num_or_size_splits=4, axis=1)

      if self._layer_norm:
        i = _norm(self._norm_gain, self._norm_shift, i, "input")
        j = _norm(self._norm_gain, self._norm_shift, j, "transform")
        f = _norm(self._norm_gain, self._norm_shift, f, "forget")
        o = _norm(self._norm_gain, self._norm_shift, o, "output")

      # Diagonal connections
      if self._use_peepholes:
        with vs.variable_scope(unit_scope):
          w_f_diag = vs.get_variable(
              "w_f_diag", shape=[self._num_units], dtype=dtype)
          w_i_diag = vs.get_variable(
              "w_i_diag", shape=[self._num_units], dtype=dtype)
          w_o_diag = vs.get_variable(
              "w_o_diag", shape=[self._num_units], dtype=dtype)

      if self._use_peepholes:
        c = (
            sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
            sigmoid(i + w_i_diag * c_prev) * self._activation(j))
      else:
        c = (
            sigmoid(f + self._forget_bias) * c_prev +
            sigmoid(i) * self._activation(j))

      if self._layer_norm:
        c = _norm(self._norm_gain, self._norm_shift, c, "state")

      if self._cell_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
        # pylint: enable=invalid-unary-operand-type
      if self._use_peepholes:
        m = sigmoid(o + w_o_diag * c) * self._activation(c)
      else:
        m = sigmoid(o) * self._activation(c)

      if self._num_proj is not None:
        with vs.variable_scope("projection"):
          m = self._linear(m, self._num_proj, bias=False)

        if self._proj_clip is not None:
          # pylint: disable=invalid-unary-operand-type
          m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
          # pylint: enable=invalid-unary-operand-type

    new_state = (rnn_cell_impl.LSTMStateTuple(c, m))
    return m, new_state


class SRUCell(rnn_cell_impl.LayerRNNCell):
  """SRU, Simple Recurrent Unit.

     Implementation based on
     Training RNNs as Fast as CNNs (cf. https://arxiv.org/abs/1709.02755).

     This variation of RNN cell is characterized by the simplified data
     dependence
     between hidden states of two consecutive time steps. Traditionally, hidden
     states from a cell at time step t-1 needs to be multiplied with a matrix
     W_hh before being fed into the ensuing cell at time step t.
     This flavor of RNN replaces the matrix multiplication between h_{t-1}
     and W_hh with a pointwise multiplication, resulting in performance
     gain.

  Args:
    num_units: int, The number of units in the SRU cell.
    activation: Nonlinearity to use.  Default: `tanh`.
    reuse: (optional) Python boolean describing whether to reuse variables
      in an existing scope.  If not `True`, and the existing scope already has
      the given variables, an error is raised.
    name: (optional) String, the name of the layer. Layers with the same name
      will share weights, but to avoid mistakes we require reuse=True in such
      cases.
  """

  def __init__(self, num_units, activation=None, reuse=None, name=None):
    super(SRUCell, self).__init__(_reuse=reuse, name=name)
    self._num_units = num_units
    self._activation = activation or math_ops.tanh

    # Restrict inputs to be 2-dimensional matrices
    self.input_spec = base_layer.InputSpec(ndim=2)

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def build(self, inputs_shape):
    if inputs_shape[1].value is None:
      raise ValueError(
          "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)

    input_depth = inputs_shape[1].value

    # pylint: disable=protected-access
    self._kernel = self.add_variable(
        rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[input_depth, 4 * self._num_units])
    # pylint: enable=protected-access
    self._bias = self.add_variable(
        rnn_cell_impl._BIAS_VARIABLE_NAME,  # pylint: disable=protected-access
        shape=[2 * self._num_units],
        initializer=init_ops.constant_initializer(0.0, dtype=self.dtype))

    self._built = True

  def call(self, inputs, state):
    """Simple recurrent unit (SRU) with num_units cells."""

    U = math_ops.matmul(inputs, self._kernel)  # pylint: disable=invalid-name
    x_bar, f_intermediate, r_intermediate, x_tx = array_ops.split(
        value=U, num_or_size_splits=4, axis=1)

    f_r = math_ops.sigmoid(
        nn_ops.bias_add(
            array_ops.concat([f_intermediate, r_intermediate], 1), self._bias))
    f, r = array_ops.split(value=f_r, num_or_size_splits=2, axis=1)

    c = f * state + (1.0 - f) * x_bar
    h = r * self._activation(c) + (1.0 - r) * x_tx

    return h, c


class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
  """Weight normalized LSTM Cell. Adapted from `rnn_cell_impl.LSTMCell`.

    The weight-norm implementation is based on:
    https://arxiv.org/abs/1602.07868
    Tim Salimans, Diederik P. Kingma.
    Weight Normalization: A Simple Reparameterization to Accelerate
    Training of Deep Neural Networks

    The default LSTM implementation based on:
    http://www.bioinf.jku.at/publications/older/2604.pdf
    S. Hochreiter and J. Schmidhuber.
    "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.

    The class uses optional peephole connections, optional cell clipping
    and an optional projection layer.

    The optional peephole implementation is based on:
    https://research.google.com/pubs/archive/43905.pdf
    Hasim Sak, Andrew Senior, and Francoise Beaufays.
    "Long short-term memory recurrent neural network architectures for
    large scale acoustic modeling." INTERSPEECH, 2014.
  """

  def __init__(self,
               num_units,
               norm=True,
               use_peepholes=False,
               cell_clip=None,
               initializer=None,
               num_proj=None,
               proj_clip=None,
               forget_bias=1,
               activation=None,
               reuse=None):
    """Initialize the parameters of a weight-normalized LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell
      norm: If `True`, apply normalization to the weight matrices. If False,
        the result is identical to that obtained from `rnn_cell_impl.LSTMCell`
      use_peepholes: bool, set `True` to enable diagonal/peephole connections.
      cell_clip: (optional) A float value, if provided the cell state is clipped
        by this value prior to the cell output activation.
      initializer: (optional) The initializer to use for the weight matrices.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices.  If None, no projection is performed.
      proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
        provided, then the projected values are clipped elementwise to within
        `[-proj_clip, proj_clip]`.
      forget_bias: Biases of the forget gate are initialized by default to 1
        in order to reduce the scale of forgetting at the beginning of
        the training.
      activation: Activation function of the inner states.  Default: `tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(WeightNormLSTMCell, self).__init__(_reuse=reuse)

    self._scope = "wn_lstm_cell"
    self._num_units = num_units
    self._norm = norm
    self._initializer = initializer
    self._use_peepholes = use_peepholes
    self._cell_clip = cell_clip
    self._num_proj = num_proj
    self._proj_clip = proj_clip
    self._activation = activation or math_ops.tanh
    self._forget_bias = forget_bias

    self._weights_variable_name = "kernel"
    self._bias_variable_name = "bias"

    if num_proj:
      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
      self._output_size = num_proj
    else:
      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
      self._output_size = num_units

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def _normalize(self, weight, name):
    """Apply weight normalization.

    Args:
      weight: a 2D tensor with known number of columns.
      name: string, variable name for the normalizer.
    Returns:
      A tensor with the same shape as `weight`.
    """

    output_size = weight.get_shape().as_list()[1]
    g = vs.get_variable(name, [output_size], dtype=weight.dtype)
    return nn_impl.l2_normalize(weight, axis=0) * g

  def _linear(self,
              args,
              output_size,
              norm,
              bias,
              bias_initializer=None,
              kernel_initializer=None):
    """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.

    Args:
      args: a 2D Tensor or a list of 2D, batch x n, Tensors.
      output_size: int, second dimension of W[i].
      norm: bool, whether to normalize the weights.
      bias: boolean, whether to add a bias term or not.
      bias_initializer: starting value to initialize the bias
        (default is all zeros).
      kernel_initializer: starting value to initialize the weight.

    Returns:
      A 2D Tensor with shape [batch x output_size] equal to
      sum_i(args[i] * W[i]), where W[i]s are newly created matrices.

    Raises:
      ValueError: if some of the arguments has unspecified or wrong shape.
    """
    if args is None or (nest.is_sequence(args) and not args):
      raise ValueError("`args` must be specified")
    if not nest.is_sequence(args):
      args = [args]

    # Calculate the total size of arguments on dimension 1.
    total_arg_size = 0
    shapes = [a.get_shape() for a in args]
    for shape in shapes:
      if shape.ndims != 2:
        raise ValueError("linear is expecting 2D arguments: %s" % shapes)
      if shape[1].value is None:
        raise ValueError("linear expects shape[1] to be provided for shape %s, "
                         "but saw %s" % (shape, shape[1]))
      else:
        total_arg_size += shape[1].value

    dtype = [a.dtype for a in args][0]

    # Now the computation.
    scope = vs.get_variable_scope()
    with vs.variable_scope(scope) as outer_scope:
      weights = vs.get_variable(
          self._weights_variable_name, [total_arg_size, output_size],
          dtype=dtype,
          initializer=kernel_initializer)
      if norm:
        wn = []
        st = 0
        with ops.control_dependencies(None):
          for i in range(len(args)):
            en = st + shapes[i][1].value
            wn.append(
                self._normalize(weights[st:en, :], name="norm_{}".format(i)))
            st = en

          weights = array_ops.concat(wn, axis=0)

      if len(args) == 1:
        res = math_ops.matmul(args[0], weights)
      else:
        res = math_ops.matmul(array_ops.concat(args, 1), weights)
      if not bias:
        return res

      with vs.variable_scope(outer_scope) as inner_scope:
        inner_scope.set_partitioner(None)
        if bias_initializer is None:
          bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)

        biases = vs.get_variable(
            self._bias_variable_name, [output_size],
            dtype=dtype,
            initializer=bias_initializer)

      return nn_ops.bias_add(res, biases)

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, batch x num_units.
      state: A tuple of state Tensors, both `2-D`, with column sizes
       `c_state` and `m_state`.

    Returns:
      A tuple containing:

      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of LSTM after reading `inputs` when
        the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    dtype = inputs.dtype
    num_units = self._num_units
    sigmoid = math_ops.sigmoid
    c, h = state

    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    with vs.variable_scope(self._scope, initializer=self._initializer):

      concat = self._linear(
          [inputs, h], 4 * num_units, norm=self._norm, bias=True)

      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)

      if self._use_peepholes:
        w_f_diag = vs.get_variable("w_f_diag", shape=[num_units], dtype=dtype)
        w_i_diag = vs.get_variable("w_i_diag", shape=[num_units], dtype=dtype)
        w_o_diag = vs.get_variable("w_o_diag", shape=[num_units], dtype=dtype)

        new_c = (
            c * sigmoid(f + self._forget_bias + w_f_diag * c) +
            sigmoid(i + w_i_diag * c) * self._activation(j))
      else:
        new_c = (
            c * sigmoid(f + self._forget_bias) +
            sigmoid(i) * self._activation(j))

      if self._cell_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        new_c = clip_ops.clip_by_value(new_c, -self._cell_clip, self._cell_clip)
        # pylint: enable=invalid-unary-operand-type
      if self._use_peepholes:
        new_h = sigmoid(o + w_o_diag * new_c) * self._activation(new_c)
      else:
        new_h = sigmoid(o) * self._activation(new_c)

      if self._num_proj is not None:
        with vs.variable_scope("projection"):
          new_h = self._linear(
              new_h, self._num_proj, norm=self._norm, bias=False)

        if self._proj_clip is not None:
          # pylint: disable=invalid-unary-operand-type
          new_h = clip_ops.clip_by_value(new_h, -self._proj_clip,
                                         self._proj_clip)
          # pylint: enable=invalid-unary-operand-type

      new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
      return new_h, new_state


class IndRNNCell(rnn_cell_impl.LayerRNNCell):
  """Independently Recurrent Neural Network (IndRNN) cell
    (cf. https://arxiv.org/abs/1803.04831).

  Args:
    num_units: int, The number of units in the RNN cell.
    activation: Nonlinearity to use.  Default: `tanh`.
    reuse: (optional) Python boolean describing whether to reuse variables
     in an existing scope.  If not `True`, and the existing scope already has
     the given variables, an error is raised.
    name: String, the name of the layer. Layers with the same name will
      share weights, but to avoid mistakes we require reuse=True in such
      cases.
    dtype: Default dtype of the layer (default of `None` means use the type
      of the first input). Required when `build` is called before `call`.
  """

  def __init__(self,
               num_units,
               activation=None,
               reuse=None,
               name=None,
               dtype=None):
    super(IndRNNCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)

    # Inputs must be 2-dimensional.
    self.input_spec = base_layer.InputSpec(ndim=2)

    self._num_units = num_units
    self._activation = activation or math_ops.tanh

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def build(self, inputs_shape):
    if inputs_shape[1].value is None:
      raise ValueError(
          "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)

    input_depth = inputs_shape[1].value
    # pylint: disable=protected-access
    self._kernel_w = self.add_variable(
        "%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[input_depth, self._num_units])
    self._kernel_u = self.add_variable(
        "%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[1, self._num_units],
        initializer=init_ops.random_uniform_initializer(
            minval=-1, maxval=1, dtype=self.dtype))
    self._bias = self.add_variable(
        rnn_cell_impl._BIAS_VARIABLE_NAME,
        shape=[self._num_units],
        initializer=init_ops.zeros_initializer(dtype=self.dtype))
    # pylint: enable=protected-access

    self.built = True

  def call(self, inputs, state):
    """IndRNN: output = new_state = act(W * input + u * state + B)."""

    gate_inputs = math_ops.matmul(inputs, self._kernel_w) + (
        state * self._kernel_u)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
    output = self._activation(gate_inputs)
    return output, output


class IndyGRUCell(rnn_cell_impl.LayerRNNCell):
  r"""Independently Gated Recurrent Unit cell.

  Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to GRUCell,
  yet with the \(U_r\), \(U_z\), and \(U\) matrices in equations 5, 6, and
  8 of http://arxiv.org/abs/1406.1078 respectively replaced by diagonal
  matrices, i.e. a Hadamard product with a single vector:

    $$r_j = \sigma\left([\mathbf W_r\mathbf x]_j +
      [\mathbf u_r\circ \mathbf h_{(t-1)}]_j\right)$$
    $$z_j = \sigma\left([\mathbf W_z\mathbf x]_j +
      [\mathbf u_z\circ \mathbf h_{(t-1)}]_j\right)$$
    $$\tilde{h}^{(t)}_j = \phi\left([\mathbf W \mathbf x]_j +
      [\mathbf u \circ \mathbf r \circ \mathbf h_{(t-1)}]_j\right)$$

  where \(\circ\) denotes the Hadamard operator. This means that each IndyGRU
  node sees only its own state, as opposed to seeing all states in the same
  layer.

  TODO(gonnet): Write a paper describing this and add a reference here.

  Args:
    num_units: int, The number of units in the GRU cell.
    activation: Nonlinearity to use.  Default: `tanh`.
    reuse: (optional) Python boolean describing whether to reuse variables
     in an existing scope.  If not `True`, and the existing scope already has
     the given variables, an error is raised.
    kernel_initializer: (optional) The initializer to use for the weight
      matrices applied to the input.
    bias_initializer: (optional) The initializer to use for the bias.
    name: String, the name of the layer. Layers with the same name will
      share weights, but to avoid mistakes we require reuse=True in such
      cases.
    dtype: Default dtype of the layer (default of `None` means use the type
      of the first input). Required when `build` is called before `call`.
  """

  def __init__(self,
               num_units,
               activation=None,
               reuse=None,
               kernel_initializer=None,
               bias_initializer=None,
               name=None,
               dtype=None):
    super(IndyGRUCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)

    # Inputs must be 2-dimensional.
    self.input_spec = base_layer.InputSpec(ndim=2)

    self._num_units = num_units
    self._activation = activation or math_ops.tanh
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def build(self, inputs_shape):
    if inputs_shape[1].value is None:
      raise ValueError(
          "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)

    input_depth = inputs_shape[1].value
    # pylint: disable=protected-access
    self._gate_kernel_w = self.add_variable(
        "gates/%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[input_depth, 2 * self._num_units],
        initializer=self._kernel_initializer)
    self._gate_kernel_u = self.add_variable(
        "gates/%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[1, 2 * self._num_units],
        initializer=init_ops.random_uniform_initializer(
            minval=-1, maxval=1, dtype=self.dtype))
    self._gate_bias = self.add_variable(
        "gates/%s" % rnn_cell_impl._BIAS_VARIABLE_NAME,
        shape=[2 * self._num_units],
        initializer=(self._bias_initializer
                     if self._bias_initializer is not None else
                     init_ops.constant_initializer(1.0, dtype=self.dtype)))
    self._candidate_kernel_w = self.add_variable(
        "candidate/%s" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[input_depth, self._num_units],
        initializer=self._kernel_initializer)
    self._candidate_kernel_u = self.add_variable(
        "candidate/%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[1, self._num_units],
        initializer=init_ops.random_uniform_initializer(
            minval=-1, maxval=1, dtype=self.dtype))
    self._candidate_bias = self.add_variable(
        "candidate/%s" % rnn_cell_impl._BIAS_VARIABLE_NAME,
        shape=[self._num_units],
        initializer=(self._bias_initializer
                     if self._bias_initializer is not None else
                     init_ops.zeros_initializer(dtype=self.dtype)))
    # pylint: enable=protected-access

    self.built = True

  def call(self, inputs, state):
    """Gated recurrent unit (GRU) with nunits cells."""

    gate_inputs = math_ops.matmul(inputs, self._gate_kernel_w) + (
        gen_array_ops.tile(state, [1, 2]) * self._gate_kernel_u)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias)

    value = math_ops.sigmoid(gate_inputs)
    r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)

    r_state = r * state

    candidate = math_ops.matmul(inputs, self._candidate_kernel_w) + (
        r_state * self._candidate_kernel_u)
    candidate = nn_ops.bias_add(candidate, self._candidate_bias)

    c = self._activation(candidate)
    new_h = u * state + (1 - u) * c
    return new_h, new_h


class IndyLSTMCell(rnn_cell_impl.LayerRNNCell):
  r"""Basic IndyLSTM recurrent network cell.

  Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to
  BasicLSTMCell, yet with the \(U_f\), \(U_i\), \(U_o\) and \(U_c\)
  matrices in
  https://en.wikipedia.org/wiki/Long_short-term_memory#LSTM_with_a_forget_gate
  replaced by diagonal matrices, i.e. a Hadamard product with a single vector:

    $$f_t = \sigma_g\left(W_f x_t + u_f \circ h_{t-1} + b_f\right)$$
    $$i_t = \sigma_g\left(W_i x_t + u_i \circ h_{t-1} + b_i\right)$$
    $$o_t = \sigma_g\left(W_o x_t + u_o \circ h_{t-1} + b_o\right)$$
    $$c_t = f_t \circ c_{t-1} +
            i_t \circ \sigma_c\left(W_c x_t + u_c \circ h_{t-1} + b_c\right)$$

  where \(\circ\) denotes the Hadamard operator. This means that each IndyLSTM
  node sees only its own state \(h\) and \(c\), as opposed to seeing all
  states in the same layer.

  We add forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting in the beginning of the training.

  It does not allow cell clipping, a projection layer, and does not
  use peep-hole connections: it is the basic baseline.

  For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}
  that follows.

  TODO(gonnet): Write a paper describing this and add a reference here.
  """

  def __init__(self,
               num_units,
               forget_bias=1.0,
               activation=None,
               reuse=None,
               kernel_initializer=None,
               bias_initializer=None,
               name=None,
               dtype=None):
    """Initialize the IndyLSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
        Must set to `0.0` manually when restoring from CudnnLSTM-trained
        checkpoints.
      activation: Activation function of the inner states.  Default: `tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.
      kernel_initializer: (optional) The initializer to use for the weight
        matrix applied to the inputs.
      bias_initializer: (optional) The initializer to use for the bias.
      name: String, the name of the layer. Layers with the same name will
        share weights, but to avoid mistakes we require reuse=True in such
        cases.
      dtype: Default dtype of the layer (default of `None` means use the type
        of the first input). Required when `build` is called before `call`.
    """
    super(IndyLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)

    # Inputs must be 2-dimensional.
    self.input_spec = base_layer.InputSpec(ndim=2)

    self._num_units = num_units
    self._forget_bias = forget_bias
    self._activation = activation or math_ops.tanh
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer

  @property
  def state_size(self):
    return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)

  @property
  def output_size(self):
    return self._num_units

  def build(self, inputs_shape):
    if inputs_shape[1].value is None:
      raise ValueError(
          "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)

    input_depth = inputs_shape[1].value
    # pylint: disable=protected-access
    self._kernel_w = self.add_variable(
        "%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[input_depth, 4 * self._num_units],
        initializer=self._kernel_initializer)
    self._kernel_u = self.add_variable(
        "%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[1, 4 * self._num_units],
        initializer=init_ops.random_uniform_initializer(
            minval=-1, maxval=1, dtype=self.dtype))
    self._bias = self.add_variable(
        rnn_cell_impl._BIAS_VARIABLE_NAME,
        shape=[4 * self._num_units],
        initializer=(self._bias_initializer
                     if self._bias_initializer is not None else
                     init_ops.zeros_initializer(dtype=self.dtype)))
    # pylint: enable=protected-access

    self.built = True

  def call(self, inputs, state):
    """Independent Long short-term memory cell (IndyLSTM).

    Args:
      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
      state: An `LSTMStateTuple` of state tensors, each shaped
        `[batch_size, num_units]`.

    Returns:
      A pair containing the new hidden state, and the new state (a
        `LSTMStateTuple`).
    """
    sigmoid = math_ops.sigmoid
    one = constant_op.constant(1, dtype=dtypes.int32)
    c, h = state

    gate_inputs = math_ops.matmul(inputs, self._kernel_w)
    gate_inputs += gen_array_ops.tile(h, [1, 4]) * self._kernel_u
    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(
        value=gate_inputs, num_or_size_splits=4, axis=one)

    forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
    # Note that using `add` and `multiply` instead of `+` and `*` gives a
    # performance improvement. So using those at the cost of readability.
    add = math_ops.add
    multiply = math_ops.multiply
    new_c = add(
        multiply(c, sigmoid(add(f, forget_bias_tensor))),
        multiply(sigmoid(i), self._activation(j)))
    new_h = multiply(self._activation(new_c), sigmoid(o))

    new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
    return new_h, new_state
